warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/tests/test_mat.py CHANGED
@@ -50,6 +50,22 @@ def get_select_kernel(dtype):
50
50
  return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
51
51
 
52
52
 
53
+ def test_shape_mismatch(test, device):
54
+ test.assertNotEqual(wp.mat33f(0.0), wp.mat22f(0.0))
55
+ test.assertNotEqual(wp.mat22f(0.0), wp.mat33f(0.0))
56
+
57
+ @wp.kernel
58
+ def kernel():
59
+ wp.expect_neq(wp.mat33f(0.0), wp.mat22f(0.0))
60
+ wp.expect_neq(wp.mat22f(0.0), wp.mat33f(0.0))
61
+
62
+ with test.assertRaisesRegex(
63
+ RuntimeError,
64
+ r"Can't test equality for objects with different types$",
65
+ ):
66
+ wp.launch(kernel, dim=1, inputs=[], device=device)
67
+
68
+
53
69
  def test_anon_constructor_error_shape_arg_missing(test, device):
54
70
  @wp.kernel
55
71
  def kernel():
@@ -127,30 +143,6 @@ def test_tpl_constructor_error_incompatible_sizes(test, device):
127
143
  wp.launch(kernel, dim=1, inputs=[], device=device)
128
144
 
129
145
 
130
- def test_tpl_constructor_error_invalid_vector_count(test, device):
131
- @wp.kernel
132
- def kernel():
133
- wp.mat33(wp.vec3(1.0, 2.0, 3.0), wp.vec3(1.0, 2.0, 3.0))
134
-
135
- with test.assertRaisesRegex(
136
- RuntimeError,
137
- r"incompatible number of column vectors given \(2\) when constructing a matrix of shape \(3, 3\)$",
138
- ):
139
- wp.launch(kernel, dim=1, inputs=[], device=device)
140
-
141
-
142
- def test_tpl_constructor_error_invalid_vector_shape(test, device):
143
- @wp.kernel
144
- def kernel():
145
- wp.mat22(wp.vec3(1.0, 2.0, 3.0), wp.vec3(4.0, 5.0, 6.0))
146
-
147
- with test.assertRaisesRegex(
148
- RuntimeError,
149
- r"incompatible column vector lengths given when constructing a matrix of shape \(2, 2\)$",
150
- ):
151
- wp.launch(kernel, dim=1, inputs=[], device=device)
152
-
153
-
154
146
  def test_tpl_constructor_error_invalid_arg_count(test, device):
155
147
  @wp.kernel
156
148
  def kernel():
@@ -234,7 +226,7 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
234
226
  c0 = s[0][0] * R[0]
235
227
  c1 = s[0][1] * R[1]
236
228
  c2 = s[0][2] * R[2]
237
- m_alt = mat44(
229
+ m_alt = wp.matrix_from_cols(
238
230
  vec4(c0[0], c0[1], c0[2], wptype(0.0)),
239
231
  vec4(c1[0], c1[1], c1[2], wptype(0.0)),
240
232
  vec4(c2[0], c2[1], c2[2], wptype(0.0)),
@@ -1066,6 +1058,124 @@ def test_svd(test, device, dtype, register_kernels=False):
1066
1058
  assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1067
1059
 
1068
1060
 
1061
+ def test_svd_2D(test, device, dtype, register_kernels=False):
1062
+ rng = np.random.default_rng(123)
1063
+
1064
+ tol = {
1065
+ np.float16: 1.0e-3,
1066
+ np.float32: 1.0e-6,
1067
+ np.float64: 1.0e-12,
1068
+ }.get(dtype, 0)
1069
+
1070
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1071
+ vec2 = wp.types.vector(length=2, dtype=wptype)
1072
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1073
+
1074
+ def check_mat_svd2(
1075
+ m2: wp.array(dtype=mat22),
1076
+ Uout: wp.array(dtype=mat22),
1077
+ sigmaout: wp.array(dtype=vec2),
1078
+ Vout: wp.array(dtype=mat22),
1079
+ outcomponents: wp.array(dtype=wptype),
1080
+ ):
1081
+ U = mat22()
1082
+ sigma = vec2()
1083
+ V = mat22()
1084
+
1085
+ wp.svd2(m2[0], U, sigma, V) # Assuming there's a 2D SVD kernel
1086
+
1087
+ Uout[0] = U
1088
+ sigmaout[0] = sigma
1089
+ Vout[0] = V
1090
+
1091
+ # multiply outputs by 2 so we've got something to backpropagate:
1092
+ idx = 0
1093
+ for i in range(2):
1094
+ for j in range(2):
1095
+ outcomponents[idx] = wptype(2) * U[i, j]
1096
+ idx = idx + 1
1097
+
1098
+ for i in range(2):
1099
+ outcomponents[idx] = wptype(2) * sigma[i]
1100
+ idx = idx + 1
1101
+
1102
+ for i in range(2):
1103
+ for j in range(2):
1104
+ outcomponents[idx] = wptype(2) * V[i, j]
1105
+ idx = idx + 1
1106
+
1107
+ kernel = getkernel(check_mat_svd2, suffix=dtype.__name__)
1108
+
1109
+ output_select_kernel = get_select_kernel(wptype)
1110
+
1111
+ if register_kernels:
1112
+ return
1113
+
1114
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype) + np.eye(2), dtype=mat22, requires_grad=True, device=device)
1115
+
1116
+ outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
1117
+ Uout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1118
+ sigmaout = wp.zeros(1, dtype=vec2, requires_grad=True, device=device)
1119
+ Vout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1120
+
1121
+ wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1122
+
1123
+ Uout_np = Uout.numpy()[0].astype(np.float64)
1124
+ sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
1125
+ Vout_np = Vout.numpy()[0].astype(np.float64)
1126
+
1127
+ assert_np_equal(
1128
+ np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m2.numpy()[0].astype(np.float64), tol=30 * tol
1129
+ )
1130
+
1131
+ if dtype == np.float16:
1132
+ # Skip gradient check for float16 due to rounding errors
1133
+ return
1134
+
1135
+ # Check gradients:
1136
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1137
+ idx = 0
1138
+ for idx in range(2 * 2 + 2 + 2 * 2):
1139
+ tape = wp.Tape()
1140
+ with tape:
1141
+ wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1142
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1143
+ tape.backward(out)
1144
+ m2grads = 1.0 * tape.gradients[m2].numpy()[0]
1145
+
1146
+ tape.zero()
1147
+
1148
+ dx = 0.0001
1149
+ fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1150
+ for ii in range(2):
1151
+ for jj in range(2):
1152
+ m2test = 1.0 * m2.numpy()
1153
+ m2test[0, ii, jj] += dx
1154
+ wp.launch(
1155
+ kernel,
1156
+ dim=1,
1157
+ inputs=[wp.array(m2test, dtype=mat22, device=device)],
1158
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1159
+ device=device,
1160
+ )
1161
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1162
+ plusval = out.numpy()[0]
1163
+
1164
+ m2test = 1.0 * m2.numpy()
1165
+ m2test[0, ii, jj] -= dx
1166
+ wp.launch(
1167
+ kernel,
1168
+ dim=1,
1169
+ inputs=[wp.array(m2test, dtype=mat22, device=device)],
1170
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1171
+ device=device,
1172
+ )
1173
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1174
+ minusval = out.numpy()[0]
1175
+
1176
+ assert_np_equal((plusval - minusval) / (2 * dx), m2grads[ii, jj], tol=fdtol)
1177
+
1178
+
1069
1179
  def test_qr(test, device, dtype, register_kernels=False):
1070
1180
  rng = np.random.default_rng(123)
1071
1181
 
@@ -1513,83 +1623,6 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
1513
1623
  tape.zero()
1514
1624
 
1515
1625
 
1516
- def test_mat_array_type_indexing(test, device, dtype, register_kernels=False):
1517
- np_type = np.dtype(dtype)
1518
- wp_type = wp.types.np_dtype_to_warp_type[np_type]
1519
-
1520
- vec2 = wp.types.vector(length=2, dtype=wp_type)
1521
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wp_type)
1522
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wp_type)
1523
-
1524
- def mattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=mat22)):
1525
- tid = wp.tid()
1526
-
1527
- t = a[tid]
1528
- t[0, 0] = x[tid]
1529
- a[tid] = t
1530
-
1531
- def mattest_in_register(x: wp.array2d(dtype=mat22), y: wp.array(dtype=vec2)):
1532
- i, j = wp.tid()
1533
-
1534
- a = mat22(wp_type(0.0))
1535
- a[0] = y[i]
1536
- a[1, 1] = wp_type(3.0)
1537
- x[i, j] = a
1538
-
1539
- def mattest_in_register_overwrite(x: wp.array2d(dtype=mat22), y: wp.array(dtype=vec2)):
1540
- i, j = wp.tid()
1541
-
1542
- a = mat22(wp_type(0.0))
1543
- a[0] = y[i]
1544
- a[0, 1] = wp_type(3.0)
1545
- x[i, j] = a
1546
-
1547
- kernel_read_write_store = getkernel(mattest_read_write_store, suffix=dtype.__name__)
1548
- kernel_in_register = getkernel(mattest_in_register, suffix=dtype.__name__)
1549
- kernel_in_register_overwrite = getkernel(mattest_in_register_overwrite, suffix=dtype.__name__)
1550
-
1551
- if register_kernels:
1552
- return
1553
-
1554
- a = wp.ones(1, dtype=mat22, device=device, requires_grad=True)
1555
- x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
1556
-
1557
- tape = wp.Tape()
1558
- with tape:
1559
- wp.launch(kernel_read_write_store, dim=1, inputs=[x, a], device=device)
1560
-
1561
- tape.backward(grads={a: wp.ones_like(a, requires_grad=False)})
1562
-
1563
- assert_np_equal(a.numpy(), np.array([[[2.0, 1.0], [1.0, 1.0]]], dtype=np_type))
1564
- assert_np_equal(x.grad.numpy(), np.array([1.0], dtype=np_type))
1565
-
1566
- tape.reset()
1567
-
1568
- x = wp.zeros((1, 1), dtype=mat22, device=device, requires_grad=True)
1569
- y = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
1570
-
1571
- with tape:
1572
- wp.launch(kernel_in_register, dim=(1, 1), inputs=[x, y], device=device)
1573
-
1574
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1575
-
1576
- assert_np_equal(x.numpy(), np.array([[[[1.0, 1.0], [0.0, 3.0]]]], dtype=np_type))
1577
- assert_np_equal(y.grad.numpy(), np.array([[1.0, 1.0]], dtype=np_type))
1578
-
1579
- tape.reset()
1580
-
1581
- x = wp.zeros((1, 1), dtype=mat22, device=device, requires_grad=True)
1582
- y = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
1583
-
1584
- with tape:
1585
- wp.launch(kernel_in_register_overwrite, dim=(1, 1), inputs=[x, y], device=device)
1586
-
1587
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1588
-
1589
- assert_np_equal(x.numpy(), np.array([[[[1.0, 3.0], [0.0, 0.0]]]], dtype=np_type))
1590
- assert_np_equal(y.grad.numpy(), np.array([[1.0, 0.0]], dtype=np_type))
1591
-
1592
-
1593
1626
  # Test matrix constructors using explicit type (float16)
1594
1627
  # note that these tests are specifically not using generics / closure
1595
1628
  # args to create kernels dynamically (like the rest of this file)
@@ -1623,10 +1656,62 @@ def test_matrix_constructor_value_func():
1623
1656
  c = mat32d()
1624
1657
  d = mat32d(c, shape=(3, 2))
1625
1658
  e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))
1626
- f = mat32d(
1627
- wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1628
- wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1659
+ f = wp.matrix(1.0, 2.0, 3.0, 4.0, shape=(2, 2), dtype=float)
1660
+
1661
+
1662
+ @wp.kernel
1663
+ def test_matrix_from_vecs():
1664
+ m1 = wp.matrix_from_cols(
1665
+ wp.vec3(1.0, 2.0, 3.0),
1666
+ wp.vec3(4.0, 5.0, 6.0),
1667
+ wp.vec3(7.0, 8.0, 9.0),
1668
+ )
1669
+ wp.expect_eq(m1[0, 0], 1.0)
1670
+ wp.expect_eq(m1[0, 1], 4.0)
1671
+ wp.expect_eq(m1[0, 2], 7.0)
1672
+ wp.expect_eq(m1[1, 0], 2.0)
1673
+ wp.expect_eq(m1[1, 1], 5.0)
1674
+ wp.expect_eq(m1[1, 2], 8.0)
1675
+ wp.expect_eq(m1[2, 0], 3.0)
1676
+ wp.expect_eq(m1[2, 1], 6.0)
1677
+ wp.expect_eq(m1[2, 2], 9.0)
1678
+
1679
+ m2 = wp.matrix_from_rows(
1680
+ wp.vec3(1.0, 2.0, 3.0),
1681
+ wp.vec3(4.0, 5.0, 6.0),
1682
+ wp.vec3(7.0, 8.0, 9.0),
1683
+ )
1684
+ wp.expect_eq(m2[0, 0], 1.0)
1685
+ wp.expect_eq(m2[0, 1], 2.0)
1686
+ wp.expect_eq(m2[0, 2], 3.0)
1687
+ wp.expect_eq(m2[1, 0], 4.0)
1688
+ wp.expect_eq(m2[1, 1], 5.0)
1689
+ wp.expect_eq(m2[1, 2], 6.0)
1690
+ wp.expect_eq(m2[2, 0], 7.0)
1691
+ wp.expect_eq(m2[2, 1], 8.0)
1692
+ wp.expect_eq(m2[2, 2], 9.0)
1693
+
1694
+ m3 = wp.matrix_from_cols(
1695
+ wp.vec3(1.0, 2.0, 3.0),
1696
+ wp.vec3(4.0, 5.0, 6.0),
1629
1697
  )
1698
+ wp.expect_eq(m3[0, 0], 1.0)
1699
+ wp.expect_eq(m3[0, 1], 4.0)
1700
+ wp.expect_eq(m3[1, 0], 2.0)
1701
+ wp.expect_eq(m3[1, 1], 5.0)
1702
+ wp.expect_eq(m3[2, 0], 3.0)
1703
+ wp.expect_eq(m3[2, 1], 6.0)
1704
+
1705
+ m4 = wp.matrix_from_rows(
1706
+ wp.vec3(1.0, 2.0, 3.0),
1707
+ wp.vec3(4.0, 5.0, 6.0),
1708
+ )
1709
+ wp.expect_eq(m4[0, 0], 1.0)
1710
+ wp.expect_eq(m4[0, 1], 2.0)
1711
+ wp.expect_eq(m4[0, 2], 3.0)
1712
+ wp.expect_eq(m4[1, 0], 4.0)
1713
+ wp.expect_eq(m4[1, 1], 5.0)
1714
+ wp.expect_eq(m4[1, 2], 6.0)
1630
1715
 
1631
1716
 
1632
1717
  # Same as above but with a default (float/int) type
@@ -1742,54 +1827,389 @@ def test_matrix_len(test, device):
1742
1827
 
1743
1828
 
1744
1829
  @wp.kernel
1745
- def matrix_augassign_kernel(
1746
- a: wp.array(dtype=wp.mat22), b: wp.array(dtype=wp.mat22), c: wp.array(dtype=wp.mat22), d: wp.array(dtype=wp.mat22)
1747
- ):
1830
+ def mat_extract_element(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=float)):
1831
+ tid = wp.tid()
1832
+
1833
+ a = x[tid]
1834
+ b = a[0, 0] + 2.0 * a[0, 1] + 3.0 * a[1, 0] + 4.0 * a[1, 1]
1835
+ y[tid] = b
1836
+
1837
+
1838
+ @wp.kernel
1839
+ def mat_extract_row(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.vec2)):
1840
+ tid = wp.tid()
1841
+
1842
+ a = x[tid]
1843
+ b = a[0] + 2.0 * a[1]
1844
+ y[tid] = b
1845
+
1846
+
1847
+ def test_mat_extract(test, device):
1848
+ # matrix element
1849
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
1850
+ y = wp.zeros(1, dtype=float, requires_grad=True, device=device)
1851
+
1852
+ tape = wp.Tape()
1853
+ with tape:
1854
+ wp.launch(mat_extract_element, 1, inputs=[x], outputs=[y], device=device)
1855
+
1856
+ y.grad = wp.ones_like(y)
1857
+ tape.backward()
1858
+
1859
+ assert_np_equal(y.numpy(), np.array([10.0], dtype=float))
1860
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float))
1861
+
1862
+ # matrix row
1863
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
1864
+ y = wp.zeros(1, dtype=wp.vec2, requires_grad=True, device=device)
1865
+
1866
+ tape = wp.Tape()
1867
+ with tape:
1868
+ wp.launch(mat_extract_row, 1, inputs=[x], outputs=[y], device=device)
1869
+
1870
+ y.grad = wp.ones_like(y)
1871
+ tape.backward()
1872
+
1873
+ assert_np_equal(y.numpy(), np.array([[3.0, 3.0]], dtype=float))
1874
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float))
1875
+
1876
+
1877
+ @wp.kernel
1878
+ def mat_assign_element(x: wp.array(dtype=float), y: wp.array(dtype=wp.mat22)):
1748
1879
  i = wp.tid()
1749
1880
 
1750
- m1 = wp.mat22()
1751
- m2 = b[i]
1881
+ a = wp.mat22()
1882
+ a[0, 0] = 1.0 * x[i]
1883
+ a[0, 1] = 2.0 * x[i]
1884
+ a[1, 0] = 3.0 * x[i]
1885
+ a[1, 1] = 4.0 * x[i]
1886
+
1887
+ y[i] = a
1888
+
1889
+
1890
+ @wp.kernel
1891
+ def mat_assign_row(x: wp.array(dtype=wp.vec2), y: wp.array(dtype=wp.mat22)):
1892
+ i = wp.tid()
1893
+
1894
+ a = wp.mat22()
1895
+ a[0] = 1.0 * x[i]
1896
+ a[1] = 2.0 * x[i]
1897
+
1898
+ y[i] = a
1899
+
1900
+
1901
+ def test_mat_assign(test, device):
1902
+ # matrix element
1903
+ x = wp.ones(1, dtype=float, requires_grad=True, device=device)
1904
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
1905
+
1906
+ tape = wp.Tape()
1907
+ with tape:
1908
+ wp.launch(mat_assign_element, 1, inputs=[x], outputs=[y], device=device)
1909
+
1910
+ y.grad = wp.ones_like(y)
1911
+ tape.backward()
1912
+
1913
+ assert_np_equal(y.numpy(), np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float))
1914
+ assert_np_equal(x.grad.numpy(), np.array([10.0], dtype=float))
1915
+
1916
+ # matrix row
1917
+ x = wp.ones(1, dtype=wp.vec2, requires_grad=True, device=device)
1918
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
1919
+
1920
+ tape = wp.Tape()
1921
+ with tape:
1922
+ wp.launch(mat_assign_row, 1, inputs=[x], outputs=[y], device=device)
1923
+
1924
+ y.grad = wp.ones_like(y)
1925
+ tape.backward()
1926
+
1927
+ assert_np_equal(y.numpy(), np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float))
1928
+ assert_np_equal(x.grad.numpy(), np.array([[3.0, 3.0]], dtype=float))
1929
+
1930
+
1931
+ def test_matrix_assign_copy(test, device):
1932
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
1933
+ try:
1934
+ wp.config.enable_vector_component_overwrites = True
1935
+
1936
+ @wp.kernel
1937
+ def mat_in_register_overwrite(x: wp.array2d(dtype=wp.mat22), y: wp.array(dtype=wp.vec2)):
1938
+ i, j = wp.tid()
1939
+
1940
+ a = wp.mat22()
1941
+ a[0] = y[i]
1942
+ a[0, 1] = 3.0
1943
+ x[i, j] = a
1944
+
1945
+ x = wp.zeros((1, 1), dtype=wp.mat22, device=device, requires_grad=True)
1946
+ y = wp.ones(1, dtype=wp.vec2, device=device, requires_grad=True)
1947
+
1948
+ tape = wp.Tape()
1949
+ with tape:
1950
+ wp.launch(mat_in_register_overwrite, dim=(1, 1), inputs=[x, y], device=device)
1951
+
1952
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1953
+
1954
+ assert_np_equal(x.numpy(), np.array([[[[1.0, 3.0], [0.0, 0.0]]]], dtype=float))
1955
+ assert_np_equal(y.grad.numpy(), np.array([[1.0, 0.0]], dtype=float))
1956
+
1957
+ finally:
1958
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
1959
+
1960
+
1961
+ @wp.kernel
1962
+ def mat_array_extract_element(x: wp.array2d(dtype=wp.mat22), y: wp.array2d(dtype=float)):
1963
+ i, j = wp.tid()
1964
+ a = x[i, j][0, 0]
1965
+ b = x[i, j][0, 1]
1966
+ c = x[i, j][1, 0]
1967
+ d = x[i, j][1, 1]
1968
+ y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
1969
+
1970
+
1971
+ @wp.kernel
1972
+ def mat_array_extract_row(x: wp.array2d(dtype=wp.mat22), y: wp.array2d(dtype=wp.vec2)):
1973
+ i, j = wp.tid()
1974
+ a = x[i, j][0]
1975
+ b = x[i, j][1]
1976
+ y[i, j] = 1.0 * a + 2.0 * b
1977
+
1978
+
1979
+ def test_mat_array_extract(test, device):
1980
+ # matrix element
1981
+ x = wp.ones((1, 1), dtype=wp.mat22, requires_grad=True, device=device)
1982
+ y = wp.zeros((1, 1), dtype=float, requires_grad=True, device=device)
1983
+
1984
+ tape = wp.Tape()
1985
+ with tape:
1986
+ wp.launch(mat_array_extract_element, (1, 1), inputs=[x], outputs=[y], device=device)
1987
+
1988
+ y.grad = wp.ones_like(y)
1989
+ tape.backward()
1990
+
1991
+ assert_np_equal(y.numpy(), np.array([[10.0]], dtype=float))
1992
+ assert_np_equal(x.grad.numpy(), np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float))
1993
+
1994
+ # matrix row
1995
+ x = wp.ones((1, 1), dtype=wp.mat22, requires_grad=True, device=device)
1996
+ y = wp.zeros((1, 1), dtype=wp.vec2, requires_grad=True, device=device)
1997
+
1998
+ tape = wp.Tape()
1999
+ with tape:
2000
+ wp.launch(mat_array_extract_row, (1, 1), inputs=[x], outputs=[y], device=device)
2001
+
2002
+ y.grad = wp.ones_like(y)
2003
+ tape.backward()
2004
+
2005
+ assert_np_equal(y.numpy(), np.array([[[3.0, 3.0]]], dtype=float))
2006
+ assert_np_equal(x.grad.numpy(), np.array([[[[1.0, 1.0], [2.0, 2.0]]]], dtype=float))
2007
+
2008
+
2009
+ """ TODO: gradient propagation for in-place array assignment
2010
+ @wp.kernel
2011
+ def mat_array_assign_element(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.mat22)):
2012
+ i, j = wp.tid()
2013
+
2014
+ y[i, j][0, 0] = 1.0 * x[i, j]
2015
+ y[i, j][0, 1] = 2.0 * x[i, j]
2016
+ y[i, j][1, 0] = 3.0 * x[i, j]
2017
+ y[i, j][1, 1] = 4.0 * x[i, j]
2018
+
2019
+
2020
+ @wp.kernel
2021
+ def mat_array_assign_row(x: wp.array2d(dtype=wp.vec3), y: wp.array2d(dtype=wp.mat(shape=(2, 3), dtype=float))):
2022
+ i, j = wp.tid()
2023
+
2024
+ y[i, j][0] = 1.0 * x[i, j]
2025
+ y[i, j][1] = 2.0 * x[i, j]
2026
+
2027
+
2028
+ def test_mat_array_assign(test, device):
2029
+ # matrix element
2030
+ x = wp.ones((1, 1), dtype=float, requires_grad=True, device=device)
2031
+ y = wp.zeros((1, 1), dtype=wp.mat22, requires_grad=True, device=device)
2032
+
2033
+ tape = wp.Tape()
2034
+ with tape:
2035
+ wp.launch(mat_array_assign_element, (1, 1), inputs=[x], outputs=[y], device=device)
2036
+
2037
+ y.grad = wp.ones_like(y)
2038
+ tape.backward()
2039
+
2040
+ assert_np_equal(y.numpy(), np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float))
2041
+ assert_np_equal(x.grad.numpy(), np.array([[10.0]], dtype=float))
2042
+
2043
+ # matrix row
2044
+ x = wp.ones((1, 1), dtype=wp.vec3, requires_grad=True, device=device)
2045
+ y = wp.zeros((1, 1), dtype=wp.mat(shape=(2, 3), dtype=float), requires_grad=True, device=device)
2046
+
2047
+ tape = wp.Tape()
2048
+ with tape:
2049
+ wp.launch(mat_array_assign_row, (1, 1), inputs=[x], outputs=[y], device=device)
2050
+
2051
+ y.grad = wp.ones_like(y)
2052
+ tape.backward()
2053
+
2054
+ assert_np_equal(y.numpy(), np.array([[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]], dtype=float))
2055
+ assert_np_equal(x.grad.numpy(), np.array([[[3.0, 3.0, 3.0]]], dtype=float))
2056
+ """
2057
+
2058
+
2059
+ @wp.kernel
2060
+ def mat_add_inplace_element(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
2061
+ i = wp.tid()
2062
+
2063
+ a = wp.mat22()
2064
+ b = x[i]
2065
+
2066
+ a[0, 0] += 1.0 * b[0, 0]
2067
+ a[0, 1] += 2.0 * b[0, 1]
2068
+ a[1, 0] += 3.0 * b[1, 0]
2069
+ a[1, 1] += 4.0 * b[1, 1]
2070
+
2071
+ y[i] = a
2072
+
2073
+
2074
+ @wp.kernel
2075
+ def mat_add_inplace_row(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
2076
+ i = wp.tid()
2077
+
2078
+ a = wp.mat22()
2079
+ b = x[i]
2080
+
2081
+ a[0] += 1.0 * b[0]
2082
+ a[1] += 2.0 * b[1]
2083
+
2084
+ y[i] = a
2085
+
2086
+
2087
+ def test_mat_add_inplace(test, device):
2088
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
2089
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
2090
+
2091
+ tape = wp.Tape()
2092
+ with tape:
2093
+ wp.launch(mat_add_inplace_element, 1, inputs=[x], outputs=[y], device=device)
2094
+
2095
+ y.grad = wp.ones_like(y)
2096
+ tape.backward()
2097
+
2098
+ assert_np_equal(y.numpy(), np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float))
2099
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float))
2100
+
2101
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
2102
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
1752
2103
 
1753
- m1[0, 0] += m2[0, 0]
1754
- m1[0, 1] += m2[0, 1]
1755
- m1[1, 0] += m2[1, 0]
1756
- m1[1, 1] += m2[1, 1]
2104
+ tape = wp.Tape()
2105
+ with tape:
2106
+ wp.launch(mat_add_inplace_row, 1, inputs=[x], outputs=[y], device=device)
2107
+
2108
+ y.grad = wp.ones_like(y)
2109
+ tape.backward()
1757
2110
 
1758
- a[i] = m1
2111
+ assert_np_equal(y.numpy(), np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float))
2112
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float))
1759
2113
 
1760
- m3 = wp.mat22()
1761
- m4 = d[i]
1762
2114
 
1763
- m3[0, 0] -= m4[0, 0]
1764
- m3[0, 1] -= m4[0, 1]
1765
- m3[1, 0] -= m4[1, 0]
1766
- m3[1, 1] -= m4[1, 1]
2115
+ @wp.kernel
2116
+ def mat_sub_inplace_element(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
2117
+ i = wp.tid()
1767
2118
 
1768
- c[i] = m3
2119
+ a = wp.mat22()
2120
+ b = x[i]
1769
2121
 
2122
+ a[0, 0] -= 1.0 * b[0, 0]
2123
+ a[0, 1] -= 2.0 * b[0, 1]
2124
+ a[1, 0] -= 3.0 * b[1, 0]
2125
+ a[1, 1] -= 4.0 * b[1, 1]
1770
2126
 
1771
- def test_matrix_augassign(test, device):
1772
- N = 3
2127
+ y[i] = a
1773
2128
 
1774
- a = wp.zeros(N, dtype=wp.mat22, requires_grad=True)
1775
- b = wp.ones(N, dtype=wp.mat22, requires_grad=True)
1776
2129
 
1777
- c = wp.zeros(N, dtype=wp.mat22, requires_grad=True)
1778
- d = wp.ones(N, dtype=wp.mat22, requires_grad=True)
2130
+ @wp.kernel
2131
+ def mat_sub_inplace_row(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
2132
+ i = wp.tid()
2133
+
2134
+ a = wp.mat22()
2135
+ b = x[i]
2136
+
2137
+ a[0] -= 1.0 * b[0]
2138
+ a[1] -= 2.0 * b[1]
2139
+
2140
+ y[i] = a
2141
+
2142
+
2143
+ def test_mat_sub_inplace(test, device):
2144
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
2145
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
1779
2146
 
1780
2147
  tape = wp.Tape()
1781
2148
  with tape:
1782
- wp.launch(matrix_augassign_kernel, N, inputs=[a, b, c, d])
2149
+ wp.launch(mat_sub_inplace_element, 1, inputs=[x], outputs=[y], device=device)
2150
+
2151
+ y.grad = wp.ones_like(y)
2152
+ tape.backward()
2153
+
2154
+ assert_np_equal(y.numpy(), np.array([[[-1.0, -2.0], [-3.0, -4.0]]], dtype=float))
2155
+ assert_np_equal(x.grad.numpy(), np.array([[[-1.0, -2.0], [-3.0, -4.0]]], dtype=float))
1783
2156
 
1784
- tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
2157
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
2158
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
1785
2159
 
1786
- assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
1787
- assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
1788
- assert_np_equal(b.grad.numpy(), wp.ones_like(a).numpy())
2160
+ tape = wp.Tape()
2161
+ with tape:
2162
+ wp.launch(mat_sub_inplace_row, 1, inputs=[x], outputs=[y], device=device)
2163
+
2164
+ y.grad = wp.ones_like(y)
2165
+ tape.backward()
1789
2166
 
1790
- assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
1791
- assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
1792
- assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
2167
+ assert_np_equal(y.numpy(), np.array([[[-1.0, -1.0], [-2.0, -2.0]]], dtype=float))
2168
+ assert_np_equal(x.grad.numpy(), np.array([[[-1.0, -1.0], [-2.0, -2.0]]], dtype=float))
2169
+
2170
+
2171
+ @wp.kernel
2172
+ def mat_array_add_inplace(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
2173
+ i = wp.tid()
2174
+
2175
+ y[i] += x[i]
2176
+
2177
+
2178
+ def test_mat_array_add_inplace(test, device):
2179
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
2180
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
2181
+
2182
+ tape = wp.Tape()
2183
+ with tape:
2184
+ wp.launch(mat_array_add_inplace, 1, inputs=[x], outputs=[y], device=device)
2185
+
2186
+ y.grad = wp.ones_like(y)
2187
+ tape.backward()
2188
+
2189
+ assert_np_equal(y.numpy(), np.array([[[1.0, 1.0], [1.0, 1.0]]], dtype=float))
2190
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 1.0], [1.0, 1.0]]], dtype=float))
2191
+
2192
+
2193
+ @wp.kernel
2194
+ def mat_array_sub_inplace(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
2195
+ i = wp.tid()
2196
+
2197
+ y[i] -= x[i]
2198
+
2199
+
2200
+ def test_mat_array_sub_inplace(test, device):
2201
+ x = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
2202
+ y = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)
2203
+
2204
+ tape = wp.Tape()
2205
+ with tape:
2206
+ wp.launch(mat_array_sub_inplace, 1, inputs=[x], outputs=[y], device=device)
2207
+
2208
+ y.grad = wp.ones_like(y)
2209
+ tape.backward()
2210
+
2211
+ assert_np_equal(y.numpy(), np.array([[[-1.0, -1.0], [-1.0, -1.0]]], dtype=float))
2212
+ assert_np_equal(x.grad.numpy(), np.array([[[-1.0, -1.0], [-1.0, -1.0]]], dtype=float))
1793
2213
 
1794
2214
 
1795
2215
  devices = get_test_devices()
@@ -1814,6 +2234,7 @@ add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=de
1814
2234
  add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
1815
2235
  add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
1816
2236
  add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
2237
+ add_kernel_test(TestMat, test_matrix_from_vecs, dim=1, devices=devices)
1817
2238
 
1818
2239
  mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
1819
2240
  add_kernel_test(
@@ -1848,6 +2269,12 @@ for dtype in np_signed_int_types + np_float_types:
1848
2269
  TestMat, f"test_matmul_{dtype.__name__}", test_matmul, devices=devices, dtype=dtype
1849
2270
  )
1850
2271
 
2272
+ add_function_test(
2273
+ TestMat,
2274
+ "test_shape_mismatch",
2275
+ test_shape_mismatch,
2276
+ devices=devices,
2277
+ )
1851
2278
  add_function_test(
1852
2279
  TestMat,
1853
2280
  "test_anon_constructor_error_shape_arg_missing",
@@ -1878,18 +2305,6 @@ add_function_test(
1878
2305
  test_tpl_constructor_error_incompatible_sizes,
1879
2306
  devices=devices,
1880
2307
  )
1881
- add_function_test(
1882
- TestMat,
1883
- "test_tpl_constructor_error_invalid_vector_count",
1884
- test_tpl_constructor_error_invalid_vector_count,
1885
- devices=devices,
1886
- )
1887
- add_function_test(
1888
- TestMat,
1889
- "test_tpl_constructor_error_invalid_vector_shape",
1890
- test_tpl_constructor_error_invalid_vector_shape,
1891
- devices=devices,
1892
- )
1893
2308
  add_function_test(
1894
2309
  TestMat,
1895
2310
  "test_tpl_constructor_error_invalid_arg_count",
@@ -1908,6 +2323,9 @@ for dtype in np_float_types:
1908
2323
  TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
1909
2324
  )
1910
2325
  add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
2326
+ add_function_test_register_kernel(
2327
+ TestMat, f"test_svd_2D{dtype.__name__}", test_svd_2D, devices=devices, dtype=dtype
2328
+ )
1911
2329
  add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
1912
2330
  add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
1913
2331
  add_function_test_register_kernel(
@@ -1920,15 +2338,18 @@ for dtype in np_float_types:
1920
2338
  TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
1921
2339
  )
1922
2340
  add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
1923
- add_function_test_register_kernel(
1924
- TestMat,
1925
- f"test_mat_array_type_indexing_{dtype.__name__}",
1926
- test_mat_array_type_indexing,
1927
- devices=devices,
1928
- dtype=dtype,
1929
- )
2341
+
1930
2342
  add_function_test(TestMat, "test_matrix_len", test_matrix_len, devices=devices)
1931
- add_function_test(TestMat, "test_matrix_augassign", test_matrix_augassign, devices=devices)
2343
+ add_function_test(TestMat, "test_mat_extract", test_mat_extract, devices=devices)
2344
+ add_function_test(TestMat, "test_mat_assign", test_mat_assign, devices=devices)
2345
+ add_function_test(TestMat, "test_matrix_assign_copy", test_matrix_assign_copy, devices=devices)
2346
+ add_function_test(TestMat, "test_mat_array_extract", test_mat_array_extract, devices=devices)
2347
+ # add_function_test(TestMat, "test_mat_array_assign", test_mat_array_assign, devices=devices)
2348
+ add_function_test(TestMat, "test_mat_add_inplace", test_mat_add_inplace, devices=devices)
2349
+ add_function_test(TestMat, "test_mat_sub_inplace", test_mat_sub_inplace, devices=devices)
2350
+ add_function_test(TestMat, "test_mat_array_add_inplace", test_mat_array_add_inplace, devices=devices)
2351
+ add_function_test(TestMat, "test_mat_array_sub_inplace", test_mat_array_sub_inplace, devices=devices)
2352
+
1932
2353
 
1933
2354
  if __name__ == "__main__":
1934
2355
  wp.clear_kernel_cache()