warp-lang 1.7.0 → 1.7.2rc1 (wheel tag: py3-none-manylinux_2_34_aarch64)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (60):
  1. warp/autograd.py +12 -2
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +1 -1
  5. warp/builtins.py +103 -66
  6. warp/codegen.py +48 -27
  7. warp/config.py +1 -1
  8. warp/context.py +112 -49
  9. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  10. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  11. warp/fem/cache.py +1 -1
  12. warp/fem/field/field.py +11 -1
  13. warp/fem/field/nodal_field.py +36 -22
  14. warp/fem/geometry/adaptive_nanogrid.py +7 -3
  15. warp/fem/geometry/trimesh.py +4 -12
  16. warp/jax_experimental/custom_call.py +14 -2
  17. warp/jax_experimental/ffi.py +100 -67
  18. warp/native/builtin.h +91 -65
  19. warp/native/svd.h +59 -49
  20. warp/native/tile.h +55 -26
  21. warp/native/volume.cpp +2 -2
  22. warp/native/volume_builder.cu +33 -22
  23. warp/native/warp.cu +1 -1
  24. warp/render/render_opengl.py +41 -34
  25. warp/render/render_usd.py +96 -6
  26. warp/sim/collide.py +11 -9
  27. warp/sim/inertia.py +189 -156
  28. warp/sim/integrator_euler.py +3 -0
  29. warp/sim/integrator_xpbd.py +3 -0
  30. warp/sim/model.py +56 -31
  31. warp/sim/render.py +4 -0
  32. warp/sparse.py +1 -1
  33. warp/stubs.py +73 -25
  34. warp/tests/assets/torus.usda +1 -1
  35. warp/tests/cuda/test_streams.py +1 -1
  36. warp/tests/sim/test_collision.py +237 -206
  37. warp/tests/sim/test_inertia.py +161 -0
  38. warp/tests/sim/test_model.py +5 -3
  39. warp/tests/sim/{flaky_test_sim_grad.py → test_sim_grad.py} +1 -4
  40. warp/tests/sim/test_xpbd.py +399 -0
  41. warp/tests/test_array.py +8 -7
  42. warp/tests/test_atomic.py +181 -2
  43. warp/tests/test_builtins_resolution.py +38 -38
  44. warp/tests/test_codegen.py +24 -3
  45. warp/tests/test_examples.py +16 -6
  46. warp/tests/test_fem.py +93 -14
  47. warp/tests/test_func.py +1 -1
  48. warp/tests/test_mat.py +416 -119
  49. warp/tests/test_quat.py +321 -137
  50. warp/tests/test_struct.py +116 -0
  51. warp/tests/test_vec.py +320 -174
  52. warp/tests/tile/test_tile.py +27 -0
  53. warp/tests/tile/test_tile_load.py +124 -0
  54. warp/tests/unittest_suites.py +2 -5
  55. warp/types.py +107 -9
  56. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/METADATA +41 -19
  57. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/RECORD +60 -57
  58. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/WHEEL +1 -1
  59. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/licenses/LICENSE.md +0 -26
  60. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2rc1.dist-info}/top_level.txt +0 -0
warp/tests/test_quat.py CHANGED
@@ -1903,89 +1903,6 @@ def test_quat_identity(test, device, dtype, register_kernels=False):
1903
1903
  assert_np_equal(output.numpy(), expected)
1904
1904
 
1905
1905
 
1906
- ############################################################
1907
-
1908
-
1909
- def test_quat_assign_inplace(test, device, dtype, register_kernels=False):
1910
- np_type = np.dtype(dtype)
1911
- wp_type = wp.types.np_dtype_to_warp_type[np_type]
1912
-
1913
- quat = wp.types.quaternion(dtype=wp_type)
1914
-
1915
- def quattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=quat)):
1916
- tid = wp.tid()
1917
-
1918
- t = a[tid]
1919
- t[0] = x[tid]
1920
- a[tid] = t
1921
-
1922
- def quattest_in_register(x: wp.array(dtype=wp_type), a: wp.array(dtype=quat)):
1923
- tid = wp.tid()
1924
-
1925
- g = wp_type(0.0)
1926
- q = a[tid]
1927
- g = q[0] + wp_type(2.0) * q[1] + wp_type(3.0) * q[2] + wp_type(4.0) * q[3]
1928
- x[tid] = g
1929
-
1930
- def quattest_component(x: wp.array(dtype=quat), y: wp.array(dtype=wp_type)):
1931
- i = wp.tid()
1932
-
1933
- a = quat()
1934
- a.x = wp_type(1.0) * y[i]
1935
- a.y = wp_type(2.0) * y[i]
1936
- a.z = wp_type(3.0) * y[i]
1937
- a.w = wp_type(4.0) * y[i]
1938
- x[i] = a
1939
-
1940
- kernel_read_write_store = getkernel(quattest_read_write_store, suffix=dtype.__name__)
1941
- kernel_in_register = getkernel(quattest_in_register, suffix=dtype.__name__)
1942
- kernel_component = getkernel(quattest_component, suffix=dtype.__name__)
1943
-
1944
- if register_kernels:
1945
- return
1946
-
1947
- a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
1948
- x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
1949
-
1950
- tape = wp.Tape()
1951
- with tape:
1952
- wp.launch(kernel_read_write_store, dim=1, inputs=[x, a], device=device)
1953
-
1954
- tape.backward(grads={a: wp.ones_like(a, requires_grad=False)})
1955
-
1956
- assert_np_equal(a.numpy(), np.array([[2.0, 1.0, 1.0, 1.0]], dtype=np_type))
1957
- assert_np_equal(x.grad.numpy(), np.array([1.0], dtype=np_type))
1958
-
1959
- tape.reset()
1960
-
1961
- a = wp.ones(1, dtype=quat, device=device, requires_grad=True)
1962
- x = wp.zeros(1, dtype=wp_type, device=device, requires_grad=True)
1963
-
1964
- with tape:
1965
- wp.launch(kernel_in_register, dim=1, inputs=[x, a], device=device)
1966
-
1967
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1968
-
1969
- assert_np_equal(x.numpy(), np.array([10.0], dtype=np_type))
1970
- assert_np_equal(a.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np_type))
1971
-
1972
- tape.reset()
1973
-
1974
- x = wp.zeros(1, dtype=quat, requires_grad=True)
1975
- y = wp.ones(1, dtype=wp_type, requires_grad=True)
1976
-
1977
- with tape:
1978
- wp.launch(kernel_component, dim=1, inputs=[x, y])
1979
-
1980
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1981
-
1982
- assert_np_equal(x.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np_type))
1983
- assert_np_equal(y.grad.numpy(), np.array([10.0], dtype=np_type))
1984
-
1985
-
1986
- ############################################################
1987
-
1988
-
1989
1906
  def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
1990
1907
  rng = np.random.default_rng(123)
1991
1908
  N = 3
@@ -2065,6 +1982,12 @@ def test_constructor_default():
2065
1982
  wp.expect_eq(qeye[2], 0.0)
2066
1983
  wp.expect_eq(qeye[3], 1.0)
2067
1984
 
1985
+ qlit = wp.quaternion(1.0, 2.0, 3.0, 4.0, dtype=float)
1986
+ wp.expect_eq(qlit[0], 1.0)
1987
+ wp.expect_eq(qlit[1], 2.0)
1988
+ wp.expect_eq(qlit[2], 3.0)
1989
+ wp.expect_eq(qlit[3], 4.0)
1990
+
2068
1991
 
2069
1992
  def test_py_arithmetic_ops(test, device, dtype):
2070
1993
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
@@ -2116,54 +2039,85 @@ def test_quat_len(test, device):
2116
2039
 
2117
2040
 
2118
2041
  @wp.kernel
2119
- def quat_augassign_kernel(
2120
- a: wp.array(dtype=wp.quat), b: wp.array(dtype=wp.quat), c: wp.array(dtype=wp.quat), d: wp.array(dtype=wp.quat)
2121
- ):
2122
- i = wp.tid()
2042
+ def quat_extract_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
2043
+ tid = wp.tid()
2123
2044
 
2124
- q1 = wp.quat()
2125
- q2 = b[i]
2045
+ a = x[tid]
2046
+ b = a[0] + 2.0 * a[1] + 3.0 * a[2] + 4.0 * a[3]
2047
+ y[tid] = b
2126
2048
 
2127
- q1[0] += q2[0]
2128
- q1[1] += q2[1]
2129
- q1[2] += q2[2]
2130
- q1[3] += q2[3]
2131
2049
 
2132
- a[i] = q1
2050
+ """ TODO: rhs attribute indexing
2051
+ @wp.kernel
2052
+ def quat_extract_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
2053
+ tid = wp.tid()
2133
2054
 
2134
- q3 = wp.quat()
2135
- q4 = d[i]
2055
+ a = x[tid]
2056
+ b = a.x + float(2.0) * a.y + 3.0 * a.z + 4.0 * a.w
2057
+ y[tid] = b
2058
+ """
2136
2059
 
2137
- q3[0] -= q4[0]
2138
- q3[1] -= q4[1]
2139
- q3[2] -= q4[2]
2140
- q3[3] -= q4[3]
2141
2060
 
2142
- c[i] = q3
2061
+ def test_quat_extract(test, device):
2062
+ def run(kernel):
2063
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
2064
+ y = wp.zeros(1, dtype=float, requires_grad=True, device=device)
2143
2065
 
2066
+ tape = wp.Tape()
2067
+ with tape:
2068
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)
2144
2069
 
2145
- def test_quat_augassign(test, device):
2146
- N = 3
2070
+ y.grad = wp.ones_like(y)
2071
+ tape.backward()
2147
2072
 
2148
- a = wp.zeros(N, dtype=wp.quat, requires_grad=True, device=device)
2149
- b = wp.ones(N, dtype=wp.quat, requires_grad=True, device=device)
2073
+ assert_np_equal(y.numpy(), np.array([10.0], dtype=float))
2074
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
2150
2075
 
2151
- c = wp.zeros(N, dtype=wp.quat, requires_grad=True, device=device)
2152
- d = wp.ones(N, dtype=wp.quat, requires_grad=True, device=device)
2076
+ run(quat_extract_subscript)
2077
+ # run(quat_extract_attribute)
2153
2078
 
2154
- tape = wp.Tape()
2155
- with tape:
2156
- wp.launch(quat_augassign_kernel, N, inputs=[a, b, c, d], device=device)
2157
2079
 
2158
- tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
2080
+ @wp.kernel
2081
+ def quat_assign_subscript(x: wp.array(dtype=float), y: wp.array(dtype=wp.quat)):
2082
+ i = wp.tid()
2083
+
2084
+ a = wp.quat()
2085
+ a[0] = 1.0 * x[i]
2086
+ a[1] = 2.0 * x[i]
2087
+ a[2] = 3.0 * x[i]
2088
+ a[3] = 4.0 * x[i]
2089
+ y[i] = a
2090
+
2091
+
2092
+ @wp.kernel
2093
+ def quat_assign_attribute(x: wp.array(dtype=float), y: wp.array(dtype=wp.quat)):
2094
+ i = wp.tid()
2095
+
2096
+ a = wp.quat()
2097
+ a.x = 1.0 * x[i]
2098
+ a.y = 2.0 * x[i]
2099
+ a.z = 3.0 * x[i]
2100
+ a.w = 4.0 * x[i]
2101
+ y[i] = a
2102
+
2103
+
2104
+ def test_quat_assign(test, device):
2105
+ def run(kernel):
2106
+ x = wp.ones(1, dtype=float, requires_grad=True, device=device)
2107
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
2108
+
2109
+ tape = wp.Tape()
2110
+ with tape:
2111
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)
2112
+
2113
+ y.grad = wp.ones_like(y)
2114
+ tape.backward()
2159
2115
 
2160
- assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
2161
- assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
2162
- assert_np_equal(b.grad.numpy(), wp.ones_like(a).numpy())
2116
+ assert_np_equal(y.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
2117
+ assert_np_equal(x.grad.numpy(), np.array([10.0], dtype=float))
2163
2118
 
2164
- assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
2165
- assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
2166
- assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
2119
+ run(quat_assign_subscript)
2120
+ run(quat_assign_attribute)
2167
2121
 
2168
2122
 
2169
2123
  def test_quat_assign_copy(test, device):
@@ -2172,32 +2126,261 @@ def test_quat_assign_copy(test, device):
2172
2126
  wp.config.enable_vector_component_overwrites = True
2173
2127
 
2174
2128
  @wp.kernel
2175
- def quat_in_register_overwrite(x: wp.array(dtype=wp.quat), a: wp.array(dtype=wp.quat)):
2129
+ def quat_assign_overwrite(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2176
2130
  tid = wp.tid()
2177
2131
 
2178
- f = wp.quat()
2179
- a_quat = a[tid]
2180
- f = a_quat
2181
- f[1] = 3.0
2132
+ a = wp.quat()
2133
+ b = x[tid]
2134
+ a = b
2135
+ a[1] = 3.0
2182
2136
 
2183
- x[tid] = f
2137
+ y[tid] = a
2184
2138
 
2185
- x = wp.zeros(1, dtype=wp.quat, device=device, requires_grad=True)
2186
- a = wp.ones(1, dtype=wp.quat, device=device, requires_grad=True)
2139
+ x = wp.ones(1, dtype=wp.quat, device=device, requires_grad=True)
2140
+ y = wp.zeros(1, dtype=wp.quat, device=device, requires_grad=True)
2187
2141
 
2188
2142
  tape = wp.Tape()
2189
2143
  with tape:
2190
- wp.launch(quat_in_register_overwrite, dim=1, inputs=[x, a], device=device)
2144
+ wp.launch(quat_assign_overwrite, dim=1, inputs=[x, y], device=device)
2191
2145
 
2192
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
2146
+ y.grad = wp.ones_like(y, requires_grad=False)
2147
+ tape.backward()
2193
2148
 
2194
- assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0, 1.0]], dtype=float))
2195
- assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0, 1.0]], dtype=float))
2149
+ assert_np_equal(y.numpy(), np.array([[1.0, 3.0, 1.0, 1.0]], dtype=float))
2150
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 0.0, 1.0, 1.0]], dtype=float))
2196
2151
 
2197
2152
  finally:
2198
2153
  wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
2199
2154
 
2200
2155
 
2156
+ @wp.kernel
2157
+ def quat_array_extract_subscript(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
2158
+ i, j = wp.tid()
2159
+ a = x[i, j][0]
2160
+ b = x[i, j][1]
2161
+ c = x[i, j][2]
2162
+ d = x[i, j][3]
2163
+ y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
2164
+
2165
+
2166
+ """ TODO: rhs attribute indexing
2167
+ @wp.kernel
2168
+ def quat_array_extract_attribute(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
2169
+ i, j = wp.tid()
2170
+ a = x[i, j].x
2171
+ b = x[i, j].y
2172
+ c = x[i, j].z
2173
+ d = x[i, j].w
2174
+ y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
2175
+ """
2176
+
2177
+
2178
+ def test_quat_array_extract(test, device):
2179
+ def run(kernel):
2180
+ x = wp.ones((1, 1), dtype=wp.quat, requires_grad=True, device=device)
2181
+ y = wp.zeros((1, 1), dtype=float, requires_grad=True, device=device)
2182
+
2183
+ tape = wp.Tape()
2184
+ with tape:
2185
+ wp.launch(kernel, (1, 1), inputs=[x], outputs=[y], device=device)
2186
+
2187
+ y.grad = wp.ones_like(y)
2188
+ tape.backward()
2189
+
2190
+ assert_np_equal(y.numpy(), np.array([[10.0]], dtype=float))
2191
+ assert_np_equal(x.grad.numpy(), np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float))
2192
+
2193
+ run(quat_array_extract_subscript)
2194
+ # run(quat_array_extract_attribute)
2195
+
2196
+
2197
+ @wp.kernel
2198
+ def quat_array_assign_subscript(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.quat)):
2199
+ i, j = wp.tid()
2200
+
2201
+ y[i, j][0] = 1.0 * x[i, j]
2202
+ y[i, j][1] = 2.0 * x[i, j]
2203
+ y[i, j][2] = 3.0 * x[i, j]
2204
+ y[i, j][3] = 4.0 * x[i, j]
2205
+
2206
+
2207
+ @wp.kernel
2208
+ def quat_array_assign_attribute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.quat)):
2209
+ i, j = wp.tid()
2210
+
2211
+ y[i, j].x = 1.0 * x[i, j]
2212
+ y[i, j].y = 2.0 * x[i, j]
2213
+ y[i, j].z = 3.0 * x[i, j]
2214
+ y[i, j].w = 4.0 * x[i, j]
2215
+
2216
+
2217
+ def test_quat_array_assign(test, device):
2218
+ def run(kernel):
2219
+ x = wp.ones((1, 1), dtype=float, requires_grad=True, device=device)
2220
+ y = wp.zeros((1, 1), dtype=wp.quat, requires_grad=True, device=device)
2221
+
2222
+ tape = wp.Tape()
2223
+ with tape:
2224
+ wp.launch(kernel, (1, 1), inputs=[x], outputs=[y], device=device)
2225
+
2226
+ y.grad = wp.ones_like(y)
2227
+ tape.backward()
2228
+
2229
+ assert_np_equal(y.numpy(), np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float))
2230
+ # TODO: gradient propagation for in-place array assignment
2231
+ # assert_np_equal(x.grad.numpy(), np.array([[10.0]], dtype=float))
2232
+
2233
+ run(quat_array_assign_subscript)
2234
+ run(quat_array_assign_attribute)
2235
+
2236
+
2237
+ @wp.kernel
2238
+ def quat_add_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2239
+ i = wp.tid()
2240
+
2241
+ a = wp.quat()
2242
+ b = x[i]
2243
+
2244
+ a[0] += 1.0 * b[0]
2245
+ a[1] += 2.0 * b[1]
2246
+ a[2] += 3.0 * b[2]
2247
+ a[3] += 4.0 * b[3]
2248
+
2249
+ y[i] = a
2250
+
2251
+
2252
+ """ TODO: rhs attribute indexing
2253
+ @wp.kernel
2254
+ def quat_add_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2255
+ i = wp.tid()
2256
+
2257
+ a = wp.quat()
2258
+ b = x[i]
2259
+
2260
+ a.x += 1.0 * b.x
2261
+ a.y += 2.0 * b.y
2262
+ a.z += 3.0 * b.z
2263
+ a.w += 4.0 * b.w
2264
+
2265
+ y[i] = a
2266
+ """
2267
+
2268
+
2269
+ def test_quat_add_inplace(test, device):
2270
+ def run(kernel):
2271
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
2272
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
2273
+
2274
+ tape = wp.Tape()
2275
+ with tape:
2276
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)
2277
+
2278
+ y.grad = wp.ones_like(y)
2279
+ tape.backward()
2280
+
2281
+ assert_np_equal(y.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
2282
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
2283
+
2284
+ run(quat_add_inplace_subscript)
2285
+ # run(quat_add_inplace_attribute)
2286
+
2287
+
2288
+ @wp.kernel
2289
+ def quat_sub_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2290
+ i = wp.tid()
2291
+
2292
+ a = wp.quat()
2293
+ b = x[i]
2294
+
2295
+ a[0] -= 1.0 * b[0]
2296
+ a[1] -= 2.0 * b[1]
2297
+ a[2] -= 3.0 * b[2]
2298
+ a[3] -= 4.0 * b[3]
2299
+
2300
+ y[i] = a
2301
+
2302
+
2303
+ """ TODO: rhs attribute indexing
2304
+ @wp.kernel
2305
+ def quat_sub_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2306
+ i = wp.tid()
2307
+
2308
+ a = wp.quat()
2309
+ b = x[i]
2310
+
2311
+ a.x -= 1.0 * b.x
2312
+ a.y -= 2.0 * b.y
2313
+ a.z -= 3.0 * b.z
2314
+ a.w -= 4.0 * b.w
2315
+
2316
+ y[i] = a
2317
+ """
2318
+
2319
+
2320
+ def test_quat_sub_inplace(test, device):
2321
+ def run(kernel):
2322
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
2323
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
2324
+
2325
+ tape = wp.Tape()
2326
+ with tape:
2327
+ wp.launch(kernel, 1, inputs=[x], outputs=[y], device=device)
2328
+
2329
+ y.grad = wp.ones_like(y)
2330
+ tape.backward()
2331
+
2332
+ assert_np_equal(y.numpy(), np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float))
2333
+ assert_np_equal(x.grad.numpy(), np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float))
2334
+
2335
+ run(quat_sub_inplace_subscript)
2336
+ # run(quat_sub_inplace_attribute)
2337
+
2338
+
2339
+ @wp.kernel
2340
+ def quat_array_add_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2341
+ i = wp.tid()
2342
+
2343
+ y[i] += x[i]
2344
+
2345
+
2346
+ def test_quat_array_add_inplace(test, device):
2347
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
2348
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
2349
+
2350
+ tape = wp.Tape()
2351
+ with tape:
2352
+ wp.launch(quat_array_add_inplace, 1, inputs=[x], outputs=[y], device=device)
2353
+
2354
+ y.grad = wp.ones_like(y)
2355
+ tape.backward()
2356
+
2357
+ assert_np_equal(y.numpy(), np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float))
2358
+ assert_np_equal(x.grad.numpy(), np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float))
2359
+
2360
+
2361
+ """ TODO: quat negation operator
2362
+ @wp.kernel
2363
+ def quat_array_sub_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
2364
+ i = wp.tid()
2365
+
2366
+ y[i] -= x[i]
2367
+
2368
+
2369
+ def test_quat_array_sub_inplace(test, device):
2370
+ x = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
2371
+ y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
2372
+
2373
+ tape = wp.Tape()
2374
+ with tape:
2375
+ wp.launch(quat_array_sub_inplace, 1, inputs=[x], outputs=[y], device=device)
2376
+
2377
+ y.grad = wp.ones_like(y)
2378
+ tape.backward()
2379
+
2380
+ assert_np_equal(y.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
2381
+ assert_np_equal(x.grad.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
2382
+ """
2383
+
2201
2384
  devices = get_test_devices()
2202
2385
 
2203
2386
 
@@ -2295,20 +2478,21 @@ for dtype in np_float_types:
2295
2478
  devices=devices,
2296
2479
  dtype=dtype,
2297
2480
  )
2298
- add_function_test_register_kernel(
2299
- TestQuat,
2300
- f"test_quat_assign_inplace_{dtype.__name__}",
2301
- test_quat_assign_inplace,
2302
- devices=devices,
2303
- dtype=dtype,
2304
- )
2305
2481
  add_function_test(
2306
2482
  TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2307
2483
  )
2308
2484
 
2309
2485
  add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices)
2310
- add_function_test(TestQuat, "test_quat_augassign", test_quat_augassign, devices=devices)
2486
+ add_function_test(TestQuat, "test_quat_extract", test_quat_extract, devices=devices)
2487
+ add_function_test(TestQuat, "test_quat_assign", test_quat_assign, devices=devices)
2311
2488
  add_function_test(TestQuat, "test_quat_assign_copy", test_quat_assign_copy, devices=devices)
2489
+ add_function_test(TestQuat, "test_quat_array_extract", test_quat_array_extract, devices=devices)
2490
+ add_function_test(TestQuat, "test_quat_array_assign", test_quat_array_assign, devices=devices)
2491
+ add_function_test(TestQuat, "test_quat_add_inplace", test_quat_add_inplace, devices=devices)
2492
+ add_function_test(TestQuat, "test_quat_sub_inplace", test_quat_sub_inplace, devices=devices)
2493
+ add_function_test(TestQuat, "test_quat_array_add_inplace", test_quat_array_add_inplace, devices=devices)
2494
+ # add_function_test(TestQuat, "test_quat_array_sub_inplace", test_quat_array_sub_inplace, devices=devices)
2495
+
2312
2496
 
2313
2497
  if __name__ == "__main__":
2314
2498
  wp.clear_kernel_cache()
warp/tests/test_struct.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import gc # Added for garbage collection tests
16
17
  import unittest
17
18
  from typing import Any
18
19
 
@@ -221,6 +222,11 @@ def test_nested_struct(test, device):
221
222
  foo.bar.y = 1.23
222
223
  foo.x = 123
223
224
 
225
+ # verify that struct attributes are instances of their original class
226
+ assert isinstance(foo, Foo.cls)
227
+ assert isinstance(foo.bar, Bar.cls)
228
+ assert isinstance(foo.bar.baz, Baz.cls)
229
+
224
230
  wp.launch(kernel_nested_struct, dim=dim, inputs=[foo], device=device)
225
231
 
226
232
  assert_array_equal(
@@ -243,6 +249,18 @@ def test_struct_attribute_error(test, device):
243
249
  )
244
250
 
245
251
 
252
+ def test_struct_inheritance_error(test, device):
253
+ with test.assertRaisesRegex(RuntimeError, r"Warp structs must be defined as base classes$"):
254
+
255
+ @wp.struct
256
+ class Parent:
257
+ x: int
258
+
259
+ @wp.struct
260
+ class Child(Parent):
261
+ y: int
262
+
263
+
246
264
  @wp.kernel
247
265
  def test_struct_instantiate(data: wp.array(dtype=int)):
248
266
  baz = Baz(data, wp.vec3(0.0, 0.0, 26.0))
@@ -643,6 +661,96 @@ def test_struct_array_hash(test, device):
643
661
  )
644
662
 
645
663
 
664
+ # Tests for garbage collection behavior with arrays in structs
665
+ @wp.struct
666
+ class StructWithArray:
667
+ data: wp.array(dtype=float)
668
+ some_value: int
669
+
670
+
671
+ @wp.kernel
672
+ def access_array_kernel(s: StructWithArray, out: wp.array(dtype=float)):
673
+ # This kernel is used to verify data integrity by reading the first element.
674
+ # Assumes s.data has at least 1 element for this test.
675
+ out[0] = s.data[0]
676
+
677
+
678
+ @wp.kernel
679
+ def compute_loss_from_struct_array_kernel(s_in: StructWithArray, loss_val: wp.array(dtype=float)):
680
+ # Compute a simple scalar loss from the array elements for grad testing.
681
+ # Assumes s_in.data has at least 2 elements for this test.
682
+ res = 0.0
683
+ res += s_in.data[0] * 2.0 # Example weight
684
+ res += s_in.data[1] * 3.0 # Example weight
685
+ loss_val[0] = res
686
+
687
+
688
+ def test_struct_array_gc_direct_assignment(test, device):
689
+ """
690
+ Tests that an array assigned to a struct (with no other direct Python
691
+ references) is not garbage collected prematurely.
692
+ """
693
+ wp.init()
694
+
695
+ s = StructWithArray()
696
+ s.some_value = 20
697
+
698
+ # Create an array, then assign it to the struct.
699
+ # After this assignment, 's.data' is the primary way to access it from
700
+ # Python's perspective, though Warp's context should also hold a reference.
701
+ local_array = wp.array([4.0, 5.0, 6.0], dtype=float, device=device)
702
+ s.data = local_array
703
+ del local_array # Remove the direct Python reference
704
+
705
+ # Force garbage collection
706
+ gc.collect()
707
+
708
+ # Attempt to access the array in a kernel
709
+ out_wp = wp.zeros(1, dtype=float, device=device)
710
+ try:
711
+ wp.launch(kernel=access_array_kernel, dim=1, inputs=[s, out_wp], device=device)
712
+
713
+ # We expect to read 4.0 if the array is still valid
714
+ assert out_wp.numpy()[0] == 4.0, "Array data was not accessible or incorrect after GC with direct assignment."
715
+ except Exception as e:
716
+ test.fail(f"Kernel execution failed after GC with direct assignment: {e}")
717
+
718
+
719
+ def test_struct_array_gc_requires_grad_toggle(test, device):
720
+ """
721
+ Tests that an array within a struct is not garbage collected prematurely
722
+ when its requires_grad flag is toggled, and that backward pass works.
723
+ """
724
+ wp.init()
725
+
726
+ s = StructWithArray()
727
+ s.some_value = 10
728
+ # Initialize array with requires_grad=True. Content: [1.0, 2.0, 3.0]
729
+ s.data = wp.array([1.0, 2.0, 3.0], dtype=float, device=device, requires_grad=True)
730
+
731
+ loss_wp = wp.zeros(1, dtype=float, device=device, requires_grad=True)
732
+
733
+ tape = wp.Tape()
734
+ with tape:
735
+ # Launch kernel that uses s.data to compute a loss
736
+ wp.launch(
737
+ kernel=compute_loss_from_struct_array_kernel,
738
+ dim=1,
739
+ inputs=[s, loss_wp],
740
+ device=device,
741
+ )
742
+
743
+ # Expected loss = 1.0*2.0 + 2.0*3.0 = 2.0 + 6.0 = 8.0
744
+
745
+ # After the forward pass is recorded, toggle requires_grad and run GC
746
+ s.data.requires_grad = False
747
+ gc.collect()
748
+
749
+ # will cause a memory access violation if grad array has been garbage collected
750
+ # or struct is not updated correctly
751
+ tape.backward(loss=loss_wp)
752
+
753
+
646
754
  devices = get_test_devices()
647
755
 
648
756
 
@@ -677,6 +785,8 @@ add_kernel_test(
677
785
  )
678
786
  add_kernel_test(TestStruct, kernel=test_return, name="test_return", dim=1, inputs=[], devices=devices)
679
787
  add_function_test(TestStruct, "test_nested_struct", test_nested_struct, devices=devices)
788
+ add_function_test(TestStruct, "test_struct_attribute_error", test_struct_attribute_error, devices=devices)
789
+ add_function_test(TestStruct, "test_struct_inheritance_error", test_struct_inheritance_error, devices=devices)
680
790
  add_function_test(TestStruct, "test_nested_array_struct", test_nested_array_struct, devices=devices)
681
791
  add_function_test(TestStruct, "test_convert_to_device", test_convert_to_device, devices=devices)
682
792
  add_function_test(TestStruct, "test_nested_empty_struct", test_nested_empty_struct, devices=devices)
@@ -727,6 +837,12 @@ add_kernel_test(
727
837
  )
728
838
 
729
839
  add_function_test(TestStruct, "test_struct_array_hash", test_struct_array_hash, devices=None)
840
+ add_function_test(
841
+ TestStruct, "test_struct_array_gc_requires_grad_toggle", test_struct_array_gc_requires_grad_toggle, devices=devices
842
+ )
843
+ add_function_test(
844
+ TestStruct, "test_struct_array_gc_direct_assignment", test_struct_array_gc_direct_assignment, devices=devices
845
+ )
730
846
 
731
847
 
732
848
  if __name__ == "__main__":