warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
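The test modules in the list above (the renames to aux_test_*.py, the removal of test_base.py/test_all.py, and the new unittest_utils.py, unittest_suites.py, and unittest_parallel.py files) reflect a move from the old register(parent) helper to plain unittest.TestCase classes populated per device. A minimal sketch of that registration pattern as it appears in the updated test_grad.py diff below; TestExample and test_something are invented placeholder names:

import unittest

import warp as wp
from warp.tests.unittest_utils import *  # provides add_function_test, get_test_devices, assert_np_equal, ...

wp.init()


def test_something(test, device):
    # per-device test body: `test` is the unittest.TestCase instance, `device` a Warp device identifier
    x = wp.zeros(8, dtype=float, device=device)
    test.assertEqual(len(x), 8)


devices = get_test_devices()


class TestExample(unittest.TestCase):
    pass


# attach the function to the TestCase once per device, as the updated test files do
add_function_test(TestExample, "test_something", test_something, devices=devices)


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)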
warp/tests/test_grad.py CHANGED
@@ -5,9 +5,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

+import unittest
+from typing import Any
+
 import numpy as np
+
 import warp as wp
-from warp.tests.test_base import *
+from warp.tests.unittest_utils import *

 wp.init()

@@ -63,26 +67,26 @@ def test_for_loop_grad(test, device):


 def test_for_loop_graph_grad(test, device):
+    wp.load_module(device=device)
+
     n = 32
     val = np.ones(n, dtype=np.float32)

     x = wp.array(val, device=device, requires_grad=True)
     sum = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

-    wp.force_load()
+    wp.capture_begin(device, force_module_load=False)
+    try:
+        tape = wp.Tape()
+        with tape:
+            wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)

-    wp.capture_begin()
-
-    tape = wp.Tape()
-    with tape:
-        wp.launch(for_loop_grad, dim=1, inputs=[n, x, sum], device=device)
-
-    tape.backward(loss=sum)
-
-    graph = wp.capture_end()
+        tape.backward(loss=sum)
+    finally:
+        graph = wp.capture_end(device)

     wp.capture_launch(graph)
-    wp.synchronize()
+    wp.synchronize_device(device)

     # ensure forward pass outputs persist
     assert_np_equal(sum.numpy(), 2.0 * np.sum(x.numpy()))
@@ -90,7 +94,7 @@ def test_for_loop_graph_grad(test, device):
     assert_np_equal(x.grad.numpy(), 2.0 * val)

     wp.capture_launch(graph)
-    wp.synchronize()
+    wp.synchronize_device(device)


 @wp.kernel
@@ -272,8 +276,7 @@ def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
     numerical gradient computed using finite differences.
     """

-    module = wp.get_module(func.__module__)
-    kernel = wp.Kernel(func=func, key=func_name, module=module)
+    kernel = wp.Kernel(func=func, key=func_name)

     def f(xs):
         # call the kernel without taping for finite differences
@@ -316,7 +319,7 @@


 def test_vector_math_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)

     # test unary operations
     for dim, vec_type in [(2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)]:
@@ -332,14 +335,14 @@ def test_vector_math_grad(test, device):

         # run the tests with 5 different random inputs
         for _ in range(5):
-            x = wp.array(np.random.randn(1, dim).astype(np.float32), dtype=vec_type, device=device)
+            x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
             gradcheck(check_length, f"check_length_{vec_type.__name__}", [x], device)
             gradcheck(check_length_sq, f"check_length_sq_{vec_type.__name__}", [x], device)
             gradcheck(check_normalize, f"check_normalize_{vec_type.__name__}", [x], device)


 def test_matrix_math_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)

     # test unary operations
     for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:
@@ -352,13 +355,13 @@ def test_matrix_math_grad(test, device):

         # run the tests with 5 different random inputs
         for _ in range(5):
-            x = wp.array(np.random.randn(1, dim, dim).astype(np.float32), ndim=1, dtype=mat_type, device=device)
+            x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
             gradcheck(check_determinant, f"check_length_{mat_type.__name__}", [x], device)
             gradcheck(check_trace, f"check_length_sq_{mat_type.__name__}", [x], device)


 def test_3d_math_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)

     # test binary operations
     def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
@@ -408,7 +411,9 @@ def test_3d_math_grad(test, device):

     # run the tests with 5 different random inputs
     for _ in range(5):
-        x = wp.array(np.random.randn(2, 3).astype(np.float32), dtype=wp.vec3, device=device, requires_grad=True)
+        x = wp.array(
+            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
+        )
         gradcheck(check_cross, "check_cross_3d", [x], device)
         gradcheck(check_dot, "check_dot_3d", [x], device)
         gradcheck(check_mat33, "check_mat33_3d", [x], device, eps=2e-2)
@@ -419,7 +424,7 @@


 def test_multi_valued_function_grad(test, device):
-    np.random.seed(123)
+    rng = np.random.default_rng(123)

     @wp.func
     def multi_valued(x: float, y: float, z: float):
@@ -434,7 +439,9 @@ def test_multi_valued_function_grad(test, device):

     # run the tests with 5 different random inputs
     for _ in range(5):
-        x = wp.array(np.random.randn(2, 3).astype(np.float32), dtype=wp.vec3, device=device, requires_grad=True)
+        x = wp.array(
+            rng.standard_normal(size=(2, 3), dtype=np.float32), dtype=wp.vec3, device=device, requires_grad=True
+        )
         gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)


@@ -467,19 +474,17 @@ def test_mesh_grad(test, device):
         c = mesh.points[k]
         return wp.length(wp.cross(b - a, c - a)) * 0.5

+    @wp.kernel
     def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
         wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))

-    module = wp.get_module(compute_area.__module__)
-    kernel = wp.Kernel(func=compute_area, key="compute_area", module=module)
-
     num_tris = int(len(indices) / 3)

     # compute analytical gradient
     tape = wp.Tape()
     output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
     with tape:
-        wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
+        wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)

     tape.backward(loss=output)

@@ -496,13 +501,13 @@ def test_mesh_grad(test, device):
             pos = wp.array(pos_np, dtype=wp.vec3, device=device)
             mesh = wp.Mesh(points=pos, indices=indices)
             output.zero_()
-            wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
+            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
             f1 = output.numpy()[0]
             pos_np[i, j] -= 2 * eps
             pos = wp.array(pos_np, dtype=wp.vec3, device=device)
             mesh = wp.Mesh(points=pos, indices=indices)
             output.zero_()
-            wp.launch(kernel, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
+            wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)
             f2 = output.numpy()[0]
             pos_np[i, j] += eps
             fd_grad[i, j] = (f1 - f2) / (2 * eps)
@@ -510,189 +515,126 @@
     assert np.allclose(ad_grad, fd_grad, atol=1e-3)


-# atomic add function that memorizes which thread incremented the counter
-# so that the correct counter value per thread can be used in the replay
-# phase of the backward pass
 @wp.func
-def reversible_increment(
-    counter: wp.array(dtype=int),
-    counter_index: int,
-    value: int,
-    thread_values: wp.array(dtype=int),
-    tid: int
-):
-    next_index = wp.atomic_add(counter, counter_index, value)
-    thread_values[tid] = next_index
-    return next_index
-
-
-@wp.func_replay(reversible_increment)
-def replay_reversible_increment(
-    counter: wp.array(dtype=int),
-    counter_index: int,
-    value: int,
-    thread_values: wp.array(dtype=int),
-    tid: int
-):
-    return thread_values[tid]
+def name_clash(a: float, b: float) -> float:
+    return a + b


-def test_custom_replay_grad(test, device):
-    num_threads = 128
-    counter = wp.zeros(1, dtype=wp.int32, device=device)
-    thread_ids = wp.zeros(num_threads, dtype=wp.int32, device=device)
-    inputs = wp.array(np.arange(num_threads, dtype=np.float32), device=device, requires_grad=True)
-    outputs = wp.zeros_like(inputs)
+@wp.func_grad(name_clash)
+def adj_name_clash(a: float, b: float, adj_ret: float):
+    # names `adj_a` and `adj_b` must not clash with function args of generated function
+    adj_a = 0.0
+    adj_b = 0.0
+    if a < 0.0:
+        adj_a = adj_ret
+    if b > 0.0:
+        adj_b = adj_ret

-    @wp.kernel
-    def run_atomic_add(
-        input: wp.array(dtype=float),
-        counter: wp.array(dtype=int),
-        thread_values: wp.array(dtype=int),
-        output: wp.array(dtype=float)
-    ):
-        tid = wp.tid()
-        idx = reversible_increment(counter, 0, 1, thread_values, tid)
-        output[idx] = input[idx] ** 2.0
+    wp.adjoint[a] += adj_a
+    wp.adjoint[b] += adj_b

-    tape = wp.Tape()
-    with tape:
-        wp.launch(run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device)

-    tape.backward(grads={outputs: wp.array(np.ones(num_threads, dtype=np.float32), device=device)})
-    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
+@wp.kernel
+def name_clash_kernel(
+    input_a: wp.array(dtype=float),
+    input_b: wp.array(dtype=float),
+    output: wp.array(dtype=float),
+):
+    tid = wp.tid()
+    output[tid] = name_clash(input_a[tid], input_b[tid])


-@wp.func
-def overload_fn(x: float, y: float):
-    return x * 3.0 + y / 3.0, y ** 2.5
+def test_name_clash(test, device):
+    # tests that no name clashes occur when variable names such as `adj_a` are used in custom gradient code
+    with wp.ScopedDevice(device):
+        input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
+        input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
+        output = wp.zeros(3, dtype=wp.float32, requires_grad=True)
+
+        tape = wp.Tape()
+        with tape:
+            wp.launch(name_clash_kernel, dim=len(input_a), inputs=[input_a, input_b], outputs=[output])

+        tape.backward(grads={output: wp.array(np.ones(len(input_a), dtype=np.float32))})

-@wp.func_grad(overload_fn)
-def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
-    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
-    wp.adjoint[y] += y * adj_ret1 * 3.0
+        assert_np_equal(input_a.grad.numpy(), np.array([0.0, 1.0, 0.0]))
+        assert_np_equal(input_b.grad.numpy(), np.array([1.0, 1.0, 0.0]))
+
+
+@wp.struct
+class NestedStruct:
+    v: wp.vec2


 @wp.struct
-class MyStruct:
-    scalar: float
-    vec: wp.vec3
+class ParentStruct:
+    a: float
+    n: NestedStruct


 @wp.func
-def overload_fn(x: MyStruct):
-    return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar ** 0.5
+def noop(a: Any):
+    pass


-@wp.func_grad(overload_fn)
-def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
-    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
-    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
-    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
-    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
+@wp.func
+def sum2(v: wp.vec2):
+    return v[0] + v[1]


 @wp.kernel
-def run_overload_float_fn(
-    xs: wp.array(dtype=float),
-    ys: wp.array(dtype=float),
-    output0: wp.array(dtype=float),
-    output1: wp.array(dtype=float)
-):
-    i = wp.tid()
-    out0, out1 = overload_fn(xs[i], ys[i])
-    output0[i] = out0
-    output1[i] = out1
+def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
+    tid = wp.tid()

+    p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))
+
+    # test that we are not losing gradients when accessing attributes
+    noop(p.a)
+    noop(p.n)
+    noop(p.n.v)
+
+    res[tid] = p.a + sum2(p.n.v)
+
+
+def test_struct_attribute_gradient(test_case, device):
+    src = wp.array([1], dtype=float, requires_grad=True)
+    res = wp.empty_like(src)

-@wp.kernel
-def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
-    i = wp.tid()
-    out0, out1, out2 = overload_fn(xs[i])
-    output[i] = out0 + out1 + out2
-
-
-def test_custom_overload_grad(test, device):
-    dim = 3
-    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
-    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
-    out0_float = wp.zeros(dim)
-    out1_float = wp.zeros(dim)
-    tape = wp.Tape()
-    with tape:
-        wp.launch(
-            run_overload_float_fn,
-            dim=dim,
-            inputs=[xs_float, ys_float],
-            outputs=[out0_float, out1_float])
-    tape.backward(grads={
-        out0_float: wp.array(np.ones(dim), dtype=wp.float32),
-        out1_float: wp.array(np.ones(dim), dtype=wp.float32)})
-    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
-    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
-
-    x0 = MyStruct()
-    x0.vec = wp.vec3(1., 2., 3.)
-    x0.scalar = 4.
-    x1 = MyStruct()
-    x1.vec = wp.vec3(5., 6., 7.)
-    x1.scalar = -1.0
-    x2 = MyStruct()
-    x2.vec = wp.vec3(8., 9., 10.)
-    x2.scalar = 19.0
-    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
-    out_struct = wp.zeros(dim)
     tape = wp.Tape()
     with tape:
-        wp.launch(
-            run_overload_struct_fn,
-            dim=dim,
-            inputs=[xs_struct],
-            outputs=[out_struct])
-    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
-    xs_struct_np = xs_struct.numpy()
-    struct_grads = xs_struct.grad.numpy()
-    # fmt: off
-    assert_np_equal(
-        np.array([g[0] for g in struct_grads]),
-        np.array([g[0] * 10.0 for g in xs_struct_np]))
-    assert_np_equal(
-        np.array([g[1][0] for g in struct_grads]),
-        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]))
-    assert_np_equal(
-        np.array([g[1][1] for g in struct_grads]),
-        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]))
-    assert_np_equal(
-        np.array([g[1][2] for g in struct_grads]),
-        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]))
-    # fmt: on
+        wp.launch(test_struct_attribute_gradient_kernel, dim=1, inputs=[src, res])
+
+    res.grad.fill_(1.0)
+    tape.backward()
+
+    test_case.assertEqual(src.grad.numpy()[0], 5.0)
+

+devices = get_test_devices()

-def register(parent):
-    devices = get_test_devices()

-    class TestGrad(parent):
-        pass
+class TestGrad(unittest.TestCase):
+    pass

-    # add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
-    add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
-    add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
-    add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
-    add_function_test(TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=wp.get_cuda_devices())
-    add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
-    add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
-    add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
-    add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
-    add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
-    add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
-    add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
-    add_function_test(TestGrad, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
-    add_function_test(TestGrad, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)

-    return TestGrad
+# add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
+add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
+add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
+add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
+add_function_test(
+    TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=get_unique_cuda_test_devices()
+)
+add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
+add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
+add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
+add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
+add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
+add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
+add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
+add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
+add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)


 if __name__ == "__main__":
-    c = register(unittest.TestCase)
+    wp.build.clear_kernel_cache()
     unittest.main(verbosity=2, failfast=False)
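As an aside on the graph-capture hunk above: in 1.0.0b6 the capture calls are device-scoped, capture is wrapped in try/finally so it is always ended even if taping raises, and wp.synchronize() is replaced by wp.synchronize_device(). A minimal sketch of that pattern, assuming a CUDA-capable device; double_kernel is an invented placeholder name:

import numpy as np

import warp as wp

wp.init()


@wp.kernel
def double_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 2.0 * x[tid]


device = "cuda:0"  # graph capture requires a CUDA device
wp.load_module(device=device)  # compile/load the module before capture begins

n = 16
x = wp.array(np.ones(n, dtype=np.float32), device=device, requires_grad=True)
y = wp.zeros_like(x)

wp.capture_begin(device, force_module_load=False)
try:
    tape = wp.Tape()
    with tape:
        wp.launch(double_kernel, dim=n, inputs=[x], outputs=[y], device=device)
    tape.backward(grads={y: wp.array(np.ones(n, dtype=np.float32), device=device)})
finally:
    # always end the capture so the device is not left in capture mode
    graph = wp.capture_end(device)

wp.capture_launch(graph)
wp.synchronize_device(device)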
warp/tests/test_grad_customs.py ADDED
@@ -0,0 +1,176 @@
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import unittest
+
+import numpy as np
+
+import warp as wp
+from warp.tests.unittest_utils import *
+
+wp.init()
+
+
+# atomic add function that memorizes which thread incremented the counter
+# so that the correct counter value per thread can be used in the replay
+# phase of the backward pass
+@wp.func
+def reversible_increment(
+    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
+):
+    next_index = wp.atomic_add(counter, counter_index, value)
+    thread_values[tid] = next_index
+    return next_index
+
+
+@wp.func_replay(reversible_increment)
+def replay_reversible_increment(
+    counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
+):
+    return thread_values[tid]
+
+
+def test_custom_replay_grad(test, device):
+    num_threads = 128
+    counter = wp.zeros(1, dtype=wp.int32, device=device)
+    thread_ids = wp.zeros(num_threads, dtype=wp.int32, device=device)
+    inputs = wp.array(np.arange(num_threads, dtype=np.float32), device=device, requires_grad=True)
+    outputs = wp.zeros_like(inputs)
+
+    @wp.kernel
+    def run_atomic_add(
+        input: wp.array(dtype=float),
+        counter: wp.array(dtype=int),
+        thread_values: wp.array(dtype=int),
+        output: wp.array(dtype=float),
+    ):
+        tid = wp.tid()
+        idx = reversible_increment(counter, 0, 1, thread_values, tid)
+        output[idx] = input[idx] ** 2.0
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(
+            run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device
+        )
+
+    tape.backward(grads={outputs: wp.array(np.ones(num_threads, dtype=np.float32), device=device)})
+    assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
+
+
+@wp.func
+def overload_fn(x: float, y: float):
+    return x * 3.0 + y / 3.0, y**2.5
+
+
+@wp.func_grad(overload_fn)
+def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
+    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
+    wp.adjoint[y] += y * adj_ret1 * 3.0
+
+
+@wp.struct
+class MyStruct:
+    scalar: float
+    vec: wp.vec3
+
+
+@wp.func
+def overload_fn(x: MyStruct):
+    return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5
+
+
+@wp.func_grad(overload_fn)
+def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
+    wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
+    wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
+    wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
+    wp.adjoint[x.vec][2] += adj_ret2 * x.vec[0] * x.vec[1] * 40.0
+
+
+@wp.kernel
+def run_overload_float_fn(
+    xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
+):
+    i = wp.tid()
+    out0, out1 = overload_fn(xs[i], ys[i])
+    output0[i] = out0
+    output1[i] = out1
+
+
+@wp.kernel
+def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=float)):
+    i = wp.tid()
+    out0, out1, out2 = overload_fn(xs[i])
+    output[i] = out0 + out1 + out2
+
+
+def test_custom_overload_grad(test, device):
+    dim = 3
+    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
+    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
+    out0_float = wp.zeros(dim)
+    out1_float = wp.zeros(dim)
+    tape = wp.Tape()
+    with tape:
+        wp.launch(run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float])
+    tape.backward(
+        grads={
+            out0_float: wp.array(np.ones(dim), dtype=wp.float32),
+            out1_float: wp.array(np.ones(dim), dtype=wp.float32),
+        }
+    )
+    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
+    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
+
+    x0 = MyStruct()
+    x0.vec = wp.vec3(1.0, 2.0, 3.0)
+    x0.scalar = 4.0
+    x1 = MyStruct()
+    x1.vec = wp.vec3(5.0, 6.0, 7.0)
+    x1.scalar = -1.0
+    x2 = MyStruct()
+    x2.vec = wp.vec3(8.0, 9.0, 10.0)
+    x2.scalar = 19.0
+    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
+    out_struct = wp.zeros(dim)
+    tape = wp.Tape()
+    with tape:
+        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct])
+    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
+    xs_struct_np = xs_struct.numpy()
+    struct_grads = xs_struct.grad.numpy()
+    # fmt: off
+    assert_np_equal(
+        np.array([g[0] for g in struct_grads]),
+        np.array([g[0] * 10.0 for g in xs_struct_np]))
+    assert_np_equal(
+        np.array([g[1][0] for g in struct_grads]),
+        np.array([g[1][1] * g[1][2] * 20.0 for g in xs_struct_np]))
+    assert_np_equal(
+        np.array([g[1][1] for g in struct_grads]),
+        np.array([g[1][0] * g[1][2] * 30.0 for g in xs_struct_np]))
+    assert_np_equal(
+        np.array([g[1][2] for g in struct_grads]),
+        np.array([g[1][0] * g[1][1] * 40.0 for g in xs_struct_np]))
+    # fmt: on
+
+
+devices = get_test_devices()
+
+
+class TestGradCustoms(unittest.TestCase):
+    pass
+
+
+add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
+
+
+if __name__ == "__main__":
+    wp.build.clear_kernel_cache()
+    unittest.main(verbosity=2, failfast=False)
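To close, a condensed usage sketch of the @wp.func_grad mechanism that the relocated test_grad_customs.py above exercises; my_square, adj_my_square, and square_kernel are invented names, and the expected gradient simply follows d(x*x)/dx = 2x:

import numpy as np

import warp as wp

wp.init()


@wp.func
def my_square(x: float):
    return x * x


# custom adjoint registered for my_square; wp.adjoint[...] accumulates input gradients
@wp.func_grad(my_square)
def adj_my_square(x: float, adj_ret: float):
    wp.adjoint[x] += 2.0 * x * adj_ret


@wp.kernel
def square_kernel(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
    i = wp.tid()
    ys[i] = my_square(xs[i])


xs = wp.array(np.arange(1.0, 4.0, dtype=np.float32), requires_grad=True)
ys = wp.zeros_like(xs)

tape = wp.Tape()
with tape:
    wp.launch(square_kernel, dim=len(xs), inputs=[xs], outputs=[ys])

tape.backward(grads={ys: wp.array(np.ones(3, dtype=np.float32))})
print(xs.grad.numpy())  # expect [2. 4. 6.]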