warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
@@ -334,19 +334,19 @@ def test_constructors(test, device, dtype, register_kernels=False):
  outcomponents: wp.array(dtype=wptype),
  ):
  # multiply outputs by 2 so we've got something to backpropagate:
- m2result = wptype(2) * mat22(vec2(input[0], input[2]), vec2(input[1], input[3]))
- m3result = wptype(2) * mat33(
+ m2result = wptype(2) * wp.matrix_from_cols(vec2(input[0], input[2]), vec2(input[1], input[3]))
+ m3result = wptype(2) * wp.matrix_from_cols(
  vec3(input[4], input[7], input[10]),
  vec3(input[5], input[8], input[11]),
  vec3(input[6], input[9], input[12]),
  )
- m4result = wptype(2) * mat44(
+ m4result = wptype(2) * wp.matrix_from_cols(
  vec4(input[13], input[17], input[21], input[25]),
  vec4(input[14], input[18], input[22], input[26]),
  vec4(input[15], input[19], input[23], input[27]),
  vec4(input[16], input[20], input[24], input[28]),
  )
- m5result = wptype(2) * mat55(
+ m5result = wptype(2) * wp.matrix_from_cols(
  vec5(input[29], input[34], input[39], input[44], input[49]),
  vec5(input[30], input[35], input[40], input[45], input[50]),
  vec5(input[31], input[36], input[41], input[46], input[51]),
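
In Warp 1.7 these tests build matrices with wp.matrix_from_cols() instead of passing column vectors to the mat22/mat33/mat44/mat55 constructors. A minimal sketch of the new call, using only the builtin shown in this hunk (the kernel and array names below are illustrative):

    import warp as wp

    @wp.kernel
    def build_from_columns(out: wp.array(dtype=wp.mat33)):
        # stack three vec3 values as the columns of a 3x3 matrix,
        # mirroring the wp.matrix_from_cols() calls in the hunk above
        c0 = wp.vec3(1.0, 0.0, 0.0)
        c1 = wp.vec3(0.0, 2.0, 0.0)
        c2 = wp.vec3(0.0, 0.0, 3.0)
        out[0] = wp.matrix_from_cols(c0, c1, c2)

    out = wp.zeros(1, dtype=wp.mat33)
    wp.launch(build_from_columns, dim=1, outputs=[out])
    print(out.numpy())  # expect a diagonal matrix diag(1, 2, 3)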
@@ -23,8 +23,6 @@ import numpy as np
  import warp as wp
  from warp.tests.unittest_utils import *

- wp.init() # For wp.context.runtime.core.is_cutlass_enabled()
-
  # kernels are defined in the global scope, to ensure wp.Kernel objects are not GC'ed in the MGPU case
  # kernel args are assigned array modes during codegen, so wp.Kernel objects generated during codegen
  # must be preserved for overwrite tracking to function
@@ -378,62 +376,6 @@ def test_copy(test, device):
  wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting


- # wp.matmul uses wp.record_func. Ensure array modes are propagated correctly.
- def test_matmul(test, device):
- if device.is_cuda and not wp.context.runtime.core.is_cutlass_enabled():
- test.skipTest("Warp was not built with CUTLASS support")
-
- saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
- try:
- wp.config.verify_autograd_array_access = True
-
- a = wp.ones((3, 3), dtype=float, requires_grad=True, device=device)
- b = wp.ones_like(a)
- c = wp.ones_like(a)
- d = wp.zeros_like(a)
-
- tape = wp.Tape()
-
- with tape:
- wp.matmul(a, b, c, d)
-
- test.assertEqual(a._is_read, True)
- test.assertEqual(b._is_read, True)
- test.assertEqual(c._is_read, True)
- test.assertEqual(d._is_read, False)
-
- finally:
- wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
-
-
- # wp.batched_matmul uses wp.record_func. Ensure array modes are propagated correctly.
- def test_batched_matmul(test, device):
- if device.is_cuda and not wp.context.runtime.core.is_cutlass_enabled():
- test.skipTest("Warp was not built with CUTLASS support")
-
- saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
- try:
- wp.config.verify_autograd_array_access = True
-
- a = wp.ones((1, 3, 3), dtype=float, requires_grad=True, device=device)
- b = wp.ones_like(a)
- c = wp.ones_like(a)
- d = wp.zeros_like(a)
-
- tape = wp.Tape()
-
- with tape:
- wp.batched_matmul(a, b, c, d)
-
- test.assertEqual(a._is_read, True)
- test.assertEqual(b._is_read, True)
- test.assertEqual(c._is_read, True)
- test.assertEqual(d._is_read, False)
-
- finally:
- wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
-
-
  # write after read warning with in-place operators within a kernel
  def test_in_place_operators_warning(test, device):
  saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
@@ -593,8 +535,6 @@ add_function_test(TestOverwrite, "test_views", test_views, devices=devices)
  add_function_test(TestOverwrite, "test_reset", test_reset, devices=devices)

  add_function_test(TestOverwrite, "test_copy", test_copy, devices=devices)
- add_function_test(TestOverwrite, "test_matmul", test_matmul, devices=devices, check_output=False)
- add_function_test(TestOverwrite, "test_batched_matmul", test_batched_matmul, devices=devices, check_output=False)
  add_function_test(TestOverwrite, "test_atomic_operations", test_atomic_operations, devices=devices)

  # Some warning are only issued during codegen, and codegen only runs on cuda_0 in the MGPU case.
warp/tests/test_quat.py CHANGED
@@ -1205,7 +1205,6 @@ def test_quat_to_matrix(test, device, dtype, register_kernels=False):

  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
  quat = wp.types.quaternion(dtype=wptype)
- mat3 = wp.types.matrix(shape=(3, 3), dtype=wptype)
  vec3 = wp.types.vector(length=3, dtype=wptype)

  def check_quat_to_matrix(
@@ -1239,7 +1238,7 @@ def test_quat_to_matrix(test, device, dtype, register_kernels=False):
  wptype(1),
  ),
  )
- result_manual = mat3(xaxis, yaxis, zaxis)
+ result_manual = wp.matrix_from_cols(xaxis, yaxis, zaxis)

  idx = 0
  for i in range(3):
@@ -1711,18 +1710,31 @@ def test_quat_rpy_grad(test, device, dtype, register_kernels=False):
  def test_quat_from_matrix(test, device, dtype, register_kernels=False):
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
  mat33 = wp.types.matrix((3, 3), wptype)
+ mat44 = wp.types.matrix((4, 4), wptype)
  quat = wp.types.quaternion(wptype)

  def quat_from_matrix(m: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
  tid = wp.tid()

- matrix = mat33(
- m[tid, 0], m[tid, 1], m[tid, 2], m[tid, 3], m[tid, 4], m[tid, 5], m[tid, 6], m[tid, 7], m[tid, 8]
+ # fmt: off
+ m3 = mat33(
+ m[tid, 0], m[tid, 1], m[tid, 2],
+ m[tid, 3], m[tid, 4], m[tid, 5],
+ m[tid, 6], m[tid, 7], m[tid, 8],
  )
+ q1 = wp.quat_from_matrix(m3)

- q = wp.quat_from_matrix(matrix)
+ m4 = mat44(
+ m[tid, 0], m[tid, 1], m[tid, 2], wptype(0.0),
+ m[tid, 3], m[tid, 4], m[tid, 5], wptype(0.0),
+ m[tid, 6], m[tid, 7], m[tid, 8], wptype(0.0),
+ wptype(0.0), wptype(0.0), wptype(0.0), wptype(1.0),
+ )
+ q2 = wp.quat_from_matrix(m4)
+ # fmt: on

- wp.atomic_add(loss, 0, q[idx])
+ wp.expect_eq(q1, q2)
+ wp.atomic_add(loss, 0, q1[idx])

  def quat_from_matrix_forward(mats: wp.array2d(dtype=wptype), loss: wp.array(dtype=wptype), idx: int):
  tid = wp.tid()
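
The updated kernel above checks that wp.quat_from_matrix() returns the same quaternion for a 3x3 rotation matrix and for a 4x4 matrix carrying that rotation in its upper-left block. A small usage sketch of the 4x4 path, assuming only the behaviour exercised by this test (kernel and array names are illustrative):

    import numpy as np
    import warp as wp

    @wp.kernel
    def rotation_to_quat(xforms: wp.array(dtype=wp.mat44), out: wp.array(dtype=wp.quat)):
        tid = wp.tid()
        # a 4x4 transform is accepted directly; only its rotation part
        # determines the resulting quaternion, per the expect_eq above
        out[tid] = wp.quat_from_matrix(xforms[tid])

    xforms = wp.array(np.eye(4, dtype=np.float32).reshape(1, 4, 4), dtype=wp.mat44)
    out = wp.zeros(1, dtype=wp.quat)
    wp.launch(rotation_to_quat, dim=1, inputs=[xforms], outputs=[out])
    print(out.numpy())  # identity rotation -> [[0. 0. 0. 1.]]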
@@ -1891,113 +1903,6 @@ def test_quat_identity(test, device, dtype, register_kernels=False):
  assert_np_equal(output.numpy(), expected)


- ############################################################
-
-
- def test_quat_assign(test, device, dtype, register_kernels=False):
- np_type = np.dtype(dtype)
- wp_type = wp.types.np_dtype_to_warp_type[np_type]
-
- quat = wp.types.quaternion(dtype=wp_type)
-
- def quattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=quat)):
- tid = wp.tid()
-
- t = a[tid]
- t[0] = x[tid]
- a[tid] = t
-
- def quattest_in_register(x: wp.array(dtype=wp_type), a: wp.array(dtype=quat)):
- tid = wp.tid()
-
- g = wp_type(0.0)
- q = a[tid]
- g = q[0] + wp_type(2.0) * q[1] + wp_type(3.0) * q[2] + wp_type(4.0) * q[3]
- x[tid] = g
-
- def quattest_in_register_overwrite(x: wp.array(dtype=quat), a: wp.array(dtype=quat)):
- tid = wp.tid()
-
- f = quat()
- a_quat = a[tid]
- f = a_quat
- f[1] = wp_type(3.0)
-
- x[tid] = f
-
- def quattest_component(x: wp.array(dtype=quat), y: wp.array(dtype=wp_type)):
- i = wp.tid()
-
- a = quat()
- a.x = wp_type(1.0) * y[i]
- a.y = wp_type(2.0) * y[i]
- a.z = wp_type(3.0) * y[i]
- a.w = wp_type(4.0) * y[i]
- x[i] = a
-
- kernel_read_write_store = getkernel(quattest_read_write_store, suffix=dtype.__name__)
- kernel_in_register = getkernel(quattest_in_register, suffix=dtype.__name__)
- kernel_in_register_overwrite = getkernel(quattest_in_register_overwrite, suffix=dtype.__name__)
- kernel_component = getkernel(quattest_component, suffix=dtype.__name__)
-
- if register_kernels:
- return
-
- a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
- x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
-
- tape = wp.Tape()
- with tape:
- wp.launch(kernel_read_write_store, dim=1, inputs=[x, a], device=device)
-
- tape.backward(grads={a: wp.ones_like(a, requires_grad=False)})
-
- assert_np_equal(a.numpy(), np.array([[2.0, 1.0, 1.0, 1.0]], dtype=np_type))
- assert_np_equal(x.grad.numpy(), np.array([1.0], dtype=np_type))
-
- tape.reset()
-
- a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
- x = wp.zeros(1, dtype=wp_type, device=device, requires_grad=True)
-
- with tape:
- wp.launch(kernel_in_register, dim=1, inputs=[x, a], device=device)
-
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
-
- assert_np_equal(x.numpy(), np.array([10.0], dtype=np_type))
- assert_np_equal(a.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np_type))
-
- tape.reset()
-
- x = wp.zeros(1, dtype=quat, requires_grad=True)
- y = wp.ones(1, dtype=wp_type, requires_grad=True)
-
- tape = wp.Tape()
- with tape:
- wp.launch(kernel_component, dim=1, inputs=[x, y])
-
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
-
- assert_np_equal(x.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np_type))
- assert_np_equal(y.grad.numpy(), np.array([10.0], dtype=np_type))
-
- x = wp.zeros(1, dtype=quat, device=device, requires_grad=True)
- a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
-
- tape = wp.Tape()
- with tape:
- wp.launch(kernel_in_register_overwrite, dim=1, inputs=[x, a], device=device)
-
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
-
- assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0, 1.0]], dtype=np_type))
- assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0, 1.0]], dtype=np_type))
-
-
- ############################################################
-
-
  def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
  rng = np.random.default_rng(123)
  N = 3
@@ -2077,6 +1982,12 @@ def test_constructor_default():
  wp.expect_eq(qeye[2], 0.0)
  wp.expect_eq(qeye[3], 1.0)

+ qlit = wp.quaternion(1.0, 2.0, 3.0, 4.0, dtype=float)
+ wp.expect_eq(qlit[0], 1.0)
+ wp.expect_eq(qlit[1], 2.0)
+ wp.expect_eq(qlit[2], 3.0)
+ wp.expect_eq(qlit[3], 4.0)
+

  def test_py_arithmetic_ops(test, device, dtype):
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
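
The added qlit checks exercise component-wise construction of a quaternion with an explicit dtype argument. A minimal sketch that mirrors that constructor call outside the test harness (kernel and array names are illustrative):

    import warp as wp

    @wp.kernel
    def make_quat(out: wp.array(dtype=wp.quat)):
        # component-wise construction with an explicit scalar dtype,
        # matching the qlit usage in test_constructor_default above
        q = wp.quaternion(1.0, 2.0, 3.0, 4.0, dtype=float)
        out[0] = q

    out = wp.zeros(1, dtype=wp.quat)
    wp.launch(make_quat, dim=1, outputs=[out])
    print(out.numpy())  # [[1. 2. 3. 4.]]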
@@ -2128,56 +2039,348 @@ def test_quat_len(test, device):


  @wp.kernel
- def vector_augassign_kernel(
- a: wp.array(dtype=wp.quat), b: wp.array(dtype=wp.quat), c: wp.array(dtype=wp.quat), d: wp.array(dtype=wp.quat)
- ):
+ def quat_extract_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
+ tid = wp.tid()
+
+ a = x[tid]
+ b = a[0] + 2.0 * a[1] + 3.0 * a[2] + 4.0 * a[3]
+ y[tid] = b
+
+
+ """ TODO: rhs attribute indexing
+ @wp.kernel
+ def quat_extract_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
+ tid = wp.tid()
+
+ a = x[tid]
+ b = a.x + float(2.0) * a.y + 3.0 * a.z + 4.0 * a.w
+ y[tid] = b
+ """
+
+
+ def test_quat_extract(test, device):
+ def run(kernel):
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
+ y = wp.zeros(1, dtype=float, requires_grad=True, device=device)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)
+
+ y.grad = wp.ones_like(y)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([10.0], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
+
+ run(quat_extract_subscript)
+ # run(quat_extract_attribute)
+
+
+ @wp.kernel
+ def quat_assign_subscript(x: wp.array(dtype=float), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ a = wp.quat()
+ a[0] = 1.0 * x[i]
+ a[1] = 2.0 * x[i]
+ a[2] = 3.0 * x[i]
+ a[3] = 4.0 * x[i]
+ y[i] = a
+
+
+ @wp.kernel
+ def quat_assign_attribute(x: wp.array(dtype=float), y: wp.array(dtype=wp.quat)):
  i = wp.tid()

- q1 = wp.quat()
- q2 = b[i]
+ a = wp.quat()
+ a.x = 1.0 * x[i]
+ a.y = 2.0 * x[i]
+ a.z = 3.0 * x[i]
+ a.w = 4.0 * x[i]
+ y[i] = a
+

- q1[0] += q2[0]
- q1[1] += q2[1]
- q1[2] += q2[2]
- q1[3] += q2[3]
+ def test_quat_assign(test, device):
+ def run(kernel):
+ x = wp.ones(1, dtype=float, requires_grad=True, device=device)
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

- a[i] = q1
+ tape = wp.Tape()
+ with tape:
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)

- q3 = wp.quat()
- q4 = d[i]
+ y.grad = wp.ones_like(y)
+ tape.backward()

- q3[0] += q4[0]
- q3[1] += q4[1]
- q3[2] += q4[2]
- q3[3] += q4[3]
+ assert_np_equal(y.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([10.0], dtype=float))

- c[i] = q1
+ run(quat_assign_subscript)
+ run(quat_assign_attribute)


- def test_vector_augassign(test, device):
- N = 3
+ def test_quat_assign_copy(test, device):
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
+ try:
+ wp.config.enable_vector_component_overwrites = True
+
+ @wp.kernel
+ def quat_assign_overwrite(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ tid = wp.tid()
+
+ a = wp.quat()
+ b = x[tid]
+ a = b
+ a[1] = 3.0
+
+ y[tid] = a
+
+ x = wp.ones(1, dtype=wp.quat, device=device, requires_grad=True)
+ y = wp.zeros(1, dtype=wp.quat, device=device, requires_grad=True)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(quat_assign_overwrite, dim=1, inputs=[x, y], device=device)
+
+ y.grad = wp.ones_like(y, requires_grad=False)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([[1.0, 3.0, 1.0, 1.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 0.0, 1.0, 1.0]], dtype=float))
+
+ finally:
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
+

- a = wp.zeros(N, dtype=wp.quat, requires_grad=True)
- b = wp.ones(N, dtype=wp.quat, requires_grad=True)
+ @wp.kernel
+ def quat_array_extract_subscript(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
+ i, j = wp.tid()
+ a = x[i, j][0]
+ b = x[i, j][1]
+ c = x[i, j][2]
+ d = x[i, j][3]
+ y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
+
+
+ """ TODO: rhs attribute indexing
+ @wp.kernel
+ def quat_array_extract_attribute(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
+ i, j = wp.tid()
+ a = x[i, j].x
+ b = x[i, j].y
+ c = x[i, j].z
+ d = x[i, j].w
+ y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
+ """
+
+
+ def test_quat_array_extract(test, device):
+ def run(kernel):
+ x = wp.ones((1, 1), dtype=wp.quat, requires_grad=True, device=device)
+ y = wp.zeros((1, 1), dtype=float, requires_grad=True, device=device)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(kernel, (1, 1), inputs=[x], outputs=[y], device=device)
+
+ y.grad = wp.ones_like(y)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([[10.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float))
+
+ run(quat_array_extract_subscript)
+ # run(quat_array_extract_attribute)
+
+
+ @wp.kernel
+ def quat_array_assign_subscript(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.quat)):
+ i, j = wp.tid()
+
+ y[i, j][0] = 1.0 * x[i, j]
+ y[i, j][1] = 2.0 * x[i, j]
+ y[i, j][2] = 3.0 * x[i, j]
+ y[i, j][3] = 4.0 * x[i, j]
+
+
+ @wp.kernel
+ def quat_array_assign_attribute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.quat)):
+ i, j = wp.tid()
+
+ y[i, j].x = 1.0 * x[i, j]
+ y[i, j].y = 2.0 * x[i, j]
+ y[i, j].z = 3.0 * x[i, j]
+ y[i, j].w = 4.0 * x[i, j]
+
+
+ def test_quat_array_assign(test, device):
+ def run(kernel):
+ x = wp.ones((1, 1), dtype=float, requires_grad=True, device=device)
+ y = wp.zeros((1, 1), dtype=wp.quat, requires_grad=True, device=device)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(kernel, (1, 1), inputs=[x], outputs=[y], device=device)
+
+ y.grad = wp.ones_like(y)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float))
+ # TODO: gradient propagation for in-place array assignment
+ # assert_np_equal(x.grad.numpy(), np.array([[10.0]], dtype=float))
+
+ run(quat_array_assign_subscript)
+ run(quat_array_assign_attribute)
+
+
+ @wp.kernel
+ def quat_add_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ a = wp.quat()
+ b = x[i]
+
+ a[0] += 1.0 * b[0]
+ a[1] += 2.0 * b[1]
+ a[2] += 3.0 * b[2]
+ a[3] += 4.0 * b[3]
+
+ y[i] = a
+
+
+ """ TODO: rhs attribute indexing
+ @wp.kernel
+ def quat_add_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ a = wp.quat()
+ b = x[i]
+
+ a.x += 1.0 * b.x
+ a.y += 2.0 * b.y
+ a.z += 3.0 * b.z
+ a.w += 4.0 * b.w
+
+ y[i] = a
+ """
+
+
+ def test_quat_add_inplace(test, device):
+ def run(kernel):
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)

- c = wp.zeros(N, dtype=wp.quat, requires_grad=True)
- d = wp.ones(N, dtype=wp.quat, requires_grad=True)
+ y.grad = wp.ones_like(y)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
+
+ run(quat_add_inplace_subscript)
+ # run(quat_add_inplace_attribute)
+
+
+ @wp.kernel
+ def quat_sub_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ a = wp.quat()
+ b = x[i]
+
+ a[0] -= 1.0 * b[0]
+ a[1] -= 2.0 * b[1]
+ a[2] -= 3.0 * b[2]
+ a[3] -= 4.0 * b[3]
+
+ y[i] = a
+
+
+ """ TODO: rhs attribute indexing
+ @wp.kernel
+ def quat_sub_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ a = wp.quat()
+ b = x[i]
+
+ a.x -= 1.0 * b.x
+ a.y -= 2.0 * b.y
+ a.z -= 3.0 * b.z
+ a.w -= 4.0 * b.w
+
+ y[i] = a
+ """
+
+
+ def test_quat_sub_inplace(test, device):
+ def run(kernel):
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)
+
+ y.grad = wp.ones_like(y)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float))
+
+ run(quat_sub_inplace_subscript)
+ # run(quat_sub_inplace_attribute)
+
+
+ @wp.kernel
+ def quat_array_add_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ y[i] += x[i]
+
+
+ def test_quat_array_add_inplace(test, device):
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)

  tape = wp.Tape()
  with tape:
- wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d])
+ wp.launch(quat_array_add_inplace, 1, inputs=[x], outputs=[y], device=device)
+
+ y.grad = wp.ones_like(y)
+ tape.backward()

- tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
+ assert_np_equal(y.numpy(), np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float))

- assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
- assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
- assert_np_equal(b.grad.numpy(), wp.ones_like(a).numpy())

- assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
- assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
- assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
+ """ TODO: quat negation operator
+ @wp.kernel
+ def quat_array_sub_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+ i = wp.tid()
+
+ y[i] -= x[i]


+ def test_quat_array_sub_inplace(test, device):
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
+
+ tape = wp.Tape()
+ with tape:
+ wp.launch(quat_array_sub_inplace, 1, inputs=[x], outputs=[y], device=device)
+
+ y.grad = wp.ones_like(y)
+ tape.backward()
+
+ assert_np_equal(y.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
+ assert_np_equal(x.grad.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
+ """
+
  devices = get_test_devices()


@@ -2275,18 +2478,20 @@ for dtype in np_float_types:
  devices=devices,
  dtype=dtype,
  )
- add_function_test_register_kernel(
- TestQuat,
- f"test_quat_assign_{dtype.__name__}",
- test_quat_assign,
- devices=devices,
- dtype=dtype,
- )
  add_function_test(
  TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
  )

  add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices)
+ add_function_test(TestQuat, "test_quat_extract", test_quat_extract, devices=devices)
+ add_function_test(TestQuat, "test_quat_assign", test_quat_assign, devices=devices)
+ add_function_test(TestQuat, "test_quat_assign_copy", test_quat_assign_copy, devices=devices)
+ add_function_test(TestQuat, "test_quat_array_extract", test_quat_array_extract, devices=devices)
+ add_function_test(TestQuat, "test_quat_array_assign", test_quat_array_assign, devices=devices)
+ add_function_test(TestQuat, "test_quat_add_inplace", test_quat_add_inplace, devices=devices)
+ add_function_test(TestQuat, "test_quat_sub_inplace", test_quat_sub_inplace, devices=devices)
+ add_function_test(TestQuat, "test_quat_array_add_inplace", test_quat_array_add_inplace, devices=devices)
+ # add_function_test(TestQuat, "test_quat_array_sub_inplace", test_quat_array_sub_inplace, devices=devices)


  if __name__ == "__main__":