warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/tests/test_map.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
import warp.context
|
|
22
|
+
import warp.tests.aux_test_name_clash1 as name_clash_module_1
|
|
23
|
+
import warp.tests.aux_test_name_clash2 as name_clash_module_2
|
|
24
|
+
from warp.tests.unittest_utils import add_function_test, assert_np_equal, get_cuda_test_devices, get_test_devices
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@wp.struct
class MyStruct:
    """Two-component struct used to exercise ``wp.map`` with user-defined struct types."""

    a: float  # first component
    b: float  # second component
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@wp.func
def add(x: MyStruct, y: MyStruct):
    # Component-wise sum of two MyStruct values; used as a custom binary operator in wp.map.
    result = MyStruct()
    result.b = x.b + y.b
    result.a = x.a + y.a
    return result
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@wp.func
def create_struct(a: float, b: float):
    # Pack two scalars into a MyStruct; lets a lambda passed to wp.map return a struct.
    packed = MyStruct()
    packed.b = b
    packed.a = a
    return packed
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_mixed_inputs(test, device):
    """wp.map accepts a mix of arrays and non-array constants (scalars, transforms) as arguments."""
    # Boolean mask selecting between two scalar constants.
    mask = wp.array([True, False, True], dtype=bool, device=device)
    selected = wp.map(wp.where, mask, 1.0, 0.0)
    assert isinstance(selected, wp.array)
    assert_np_equal(selected.numpy(), np.array([1.0, 0.0, 1.0], dtype=np.float32))

    # One transform constant applied to every point in an array.
    rotation = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), -wp.half_pi)
    xform = wp.transform(wp.vec3(1.0, 2.0, 3.0), rotation)
    pts = wp.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=wp.vec3, device=device)
    moved = wp.map(wp.transform_point, xform, pts)
    assert isinstance(moved, wp.array)
    reference = wp.array([(2.0, 5.0, 1.0), (5.0, 8.0, -2.0)], dtype=wp.vec3, device=device)
    assert_np_equal(moved.numpy(), reference.numpy(), tol=1e-6)

    # Scalar clamp bounds combined with an array of values.
    values = wp.array([-10.0, -5.0, 0.5, 2.0, 8.0], dtype=wp.float32, device=device)
    clamped = wp.map(wp.clamp, values, -0.5, 0.5)
    assert isinstance(clamped, wp.array)
    assert_np_equal(clamped.numpy(), np.array([-0.5, -0.5, 0.5, 0.5, 0.5]))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def test_lambda(test, device):
    """wp.map accepts Python lambdas, including ones that call wp.funcs or capture local variables."""
    src = wp.array(np.arange(10, dtype=np.float32), device=device)

    # Plain arithmetic lambda.
    shifted = wp.map(lambda a: a + 2.0, src)
    assert isinstance(shifted, wp.array)
    assert_np_equal(shifted.numpy(), np.array(np.arange(10) + 2.0, dtype=np.float32))

    # Lambda that returns a user-defined struct built by a wp.func.
    structs = wp.map(lambda a: create_struct(a, a), src)
    assert isinstance(structs, wp.array)
    struct_list = structs.list()
    for idx in range(10):
        test.assertEqual(struct_list[idx].a, idx)
        test.assertEqual(struct_list[idx].b, idx)

    # Lambda capturing a local Python variable.
    offset = 3.0
    out2 = wp.map(lambda a: a + offset, src)
    assert isinstance(out2, wp.array)
    assert_np_equal(out2.numpy(), np.array(np.arange(10) + offset, dtype=np.float32))

    # A wp.func capturing the same local; the lambda returns a tuple written to two outputs.
    @wp.func
    def my_func(a: float):
        return a * offset

    summed = wp.empty_like(src)
    wp.map(lambda a: (a + offset, my_func(a)), src, out=[summed, out2])
    assert_np_equal(summed.numpy(), np.array(np.arange(10) + offset, dtype=np.float32))
    assert_np_equal(out2.numpy(), np.array(np.arange(10) * offset, dtype=np.float32))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def test_multiple_return_values(test, device):
    """A function returning a tuple maps to a list with one output array per returned value."""

    @wp.func
    def multiple_return(a: float):
        return a + 2.0, a + 3.0, wp.vec3(a, a, a)

    def check_outputs(results):
        # wp.map yields a Python list of arrays when the mapped function returns several values.
        assert isinstance(results, list)
        lists = [arr.list() for arr in results]
        for i in range(10):
            test.assertEqual(lists[0][i], i + 2.0)
            test.assertEqual(lists[1][i], i + 3.0)
            vec = lists[2][i]
            test.assertEqual(vec.x, i)
            test.assertEqual(vec.y, i)
            test.assertEqual(vec.z, i)

    src = wp.array(np.arange(10, dtype=np.float32), device=device)

    # Mapping the wp.func directly...
    direct = wp.map(multiple_return, src)
    check_outputs(direct)

    # ...and through a lambda wrapper must produce the same outputs.
    wrapped = wp.map(lambda a: multiple_return(a), src)
    check_outputs(wrapped)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def test_custom_struct_operator(test, device):
    """wp.map applies a user wp.func element-wise over two struct-typed arrays."""

    def make(a, b):
        # Host-side helper building a MyStruct instance.
        s = MyStruct()
        s.a = a
        s.b = b
        return s

    lhs = wp.array([make(1.0, 2.0), make(3.0, 4.0)], dtype=MyStruct, device=device)
    rhs = wp.array([make(10.0, 20.0), make(30.0, 40.0)], dtype=MyStruct, device=device)

    sums = wp.map(add, lhs, rhs)
    assert isinstance(sums, wp.array)
    sum_list = sums.list()
    test.assertEqual(sum_list[0].a, 11.0)
    test.assertEqual(sum_list[0].b, 22.0)
    test.assertEqual(sum_list[1].a, 33.0)
    test.assertEqual(sum_list[1].b, 44.0)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def test_name_clash(test, device):
    """Structs that share a class name across modules remain distinct argument types under wp.map."""
    vec5 = wp.types.vector(5, dtype=wp.float32)

    @wp.func
    def name_clash_structs_args_func(
        s1: name_clash_module_1.SameStruct,
        s2: name_clash_module_2.SameStruct,
        d1: name_clash_module_1.DifferentStruct,
        d2: name_clash_module_2.DifferentStruct,
    ):
        # Gather one value from each struct (two from the vec2 field) into a single vec5.
        return vec5(s1.x, s2.x, d1.v, d2.v[0], d2.v[1])

    same_1 = name_clash_module_1.SameStruct()
    same_1.x = 1.0
    same_2 = name_clash_module_2.SameStruct()
    same_2.x = 2.0
    diff_1 = name_clash_module_1.DifferentStruct()
    diff_1.v = 3.0
    diff_2 = name_clash_module_2.DifferentStruct()
    diff_2.v = wp.vec2(4.0, 5.0)

    result = wp.map(
        name_clash_structs_args_func,
        wp.array([same_1], dtype=name_clash_module_1.SameStruct, device=device),
        wp.array([same_2], dtype=name_clash_module_2.SameStruct, device=device),
        wp.array([diff_1], dtype=name_clash_module_1.DifferentStruct, device=device),
        wp.array([diff_2], dtype=name_clash_module_2.DifferentStruct, device=device),
    )
    assert isinstance(result, wp.array)
    assert_np_equal(result.numpy(), np.array([[1, 2, 3, 4, 5]], dtype=np.float32))
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_gradient(test, device):
    """Adjoints propagate back through an array produced by wp.map inside a recorded Tape."""

    @wp.func
    def affine(a: float):
        return 2.0 * a - 10.0

    inputs = wp.array(np.arange(10, dtype=np.float32), requires_grad=True, device=device)
    assert inputs.grad is not None

    tape = wp.Tape()
    with tape:
        mapped = wp.map(affine, inputs)
    assert isinstance(mapped, wp.array)
    assert mapped.grad is not None

    # Seed the output adjoint with ones and run the backward pass.
    mapped.grad.fill_(1.0)
    tape.backward()

    # d(2a - 10)/da == 2 for every element.
    grad_ref = np.full(10, 2.0, dtype=np.float32)
    assert_np_equal(inputs.grad.numpy(), grad_ref)

    # In-place arithmetic on the gradient array itself must also work.
    inputs.grad *= 2.0
    assert_np_equal(inputs.grad.numpy(), grad_ref * 2.0)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_array_ops(test, device):
    """Arithmetic operators on warp arrays agree with the equivalent NumPy operations."""
    a = wp.array(np.arange(10, dtype=np.float32), device=device)
    b = wp.array(np.arange(1, 11, dtype=np.float32), device=device)
    a_np = a.numpy()
    b_np = b.numpy()

    def check(result, reference):
        # Compare a warp array result against a NumPy reference.
        assert_np_equal(result.numpy(), reference)

    # Unary operators.
    check(+a, a_np)
    check(-a, -a_np)
    # Addition / subtraction, array-array and array-scalar in both operand orders.
    check(a + b, a_np + b_np)
    check(a + 2.0, a_np + 2.0)
    check(a - b, a_np - b_np)
    check(a - 2.0, a_np - 2.0)
    check(2.0 - a, 2.0 - a_np)
    # Multiplication.
    check(a * b, a_np * b_np)
    check(2.0 * a, 2.0 * a_np)
    check(a * 2.0, a_np * 2.0)
    # Powers are compared with a tolerance rather than exactly.
    np.testing.assert_allclose((a**b).numpy(), a_np**b_np, rtol=1.5e-7)
    np.testing.assert_allclose((a**2.0).numpy(), a_np**2.0)
    # True and floor division.
    check(a / b, a_np / b_np)
    check(a / 2.0, a_np / 2.0)
    check(a // b, a_np // b_np)
    check(a // 2.0, a_np // 2.0)
    check(2.0 / b, 2.0 / b_np)

    ai = wp.array(np.arange(10, dtype=np.int32), device=device)
    bi = wp.array(np.arange(1, 11, dtype=np.int32), device=device)
    ai_np = ai.numpy()
    bi_np = bi.numpy()
    quotient = ai / bi
    # NOTE: unlike NumPy, `/` on Warp int32 arrays is integer division and keeps the int32 dtype.
    test.assertEqual(quotient.dtype, wp.int32)
    check(quotient, ai_np // bi_np)

    @wp.func
    def make_vec(a: float):
        return wp.vec3(a, a + 1.0, a + 2.0)

    # Operators on vector-typed arrays.
    vecs_a = wp.map(make_vec, a)
    vecs_b = wp.map(make_vec, b)
    assert isinstance(vecs_a, wp.array)
    assert isinstance(vecs_b, wp.array)
    va_np = vecs_a.numpy()
    vb_np = vecs_b.numpy()
    check(-vecs_a, -va_np)
    check(vecs_a + vecs_b, va_np + vb_np)
    check(vecs_a * 2.0, va_np * 2.0)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def test_indexedarrays(test, device):
    """Indexed arrays work as wp.map inputs/outputs and with array arithmetic operators."""
    base = wp.array(data=np.arange(10, dtype=np.float32), device=device)
    odd_idx = wp.array([1, 3, 5, 7, 9], dtype=int, device=device)
    view = wp.indexedarray1d(base, [odd_idx])

    # Mapping over an indexed array produces a regular array.
    scaled = wp.map(lambda x: x * 10.0, view)
    assert isinstance(scaled, wp.array)
    tenfold = np.array([10.0, 30.0, 50.0, 70.0, 90.0], dtype=np.float32)
    assert_np_equal(scaled.numpy(), tenfold)

    # Writing the mapped result back into the indexed array itself.
    wp.map(lambda x: x * 10.0, view, out=view)
    assert isinstance(view, wp.indexedarray)
    assert_np_equal(view.numpy(), tenfold)

    # A binary operator with an indexed-array operand returns a new regular array.
    hundredfold = 10.0 * view
    assert isinstance(hundredfold, wp.array)
    assert_np_equal(hundredfold.numpy(), np.array([100.0, 300.0, 500.0, 700.0, 900.0], dtype=np.float32))

    # In-place operators keep the indexed-array type.
    view += 1.0
    assert isinstance(view, wp.indexedarray)
    assert_np_equal(view.numpy(), np.array([11.0, 31.0, 51.0, 71.0, 91.0], dtype=np.float32))
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def test_broadcasting(test, device):
    """wp.map broadcasts array inputs of different shapes/ranks NumPy-style.

    ``assert_np_equal`` compares values, not shapes, so every case also asserts
    the broadcast output shape explicitly.
    """
    a = wp.array(np.zeros((1, 3, 1, 4), dtype=np.float32), device=device)
    b = wp.array(np.ones((5, 4), dtype=np.float32), device=device)

    # (1, 3, 1, 4) combined with (5, 4) broadcasts to (1, 3, 5, 4).
    out = wp.map(wp.add, a, b)
    assert isinstance(out, wp.array)
    test.assertEqual(out.shape, (1, 3, 5, 4))
    expected = np.ones((1, 3, 5, 4), dtype=np.float32)
    assert_np_equal(out.numpy(), expected)

    # Broadcasting is symmetric: swapping operands must yield the same shape and values.
    out = wp.map(wp.add, b, a)
    assert isinstance(out, wp.array)
    test.assertEqual(out.shape, (1, 3, 5, 4))
    expected = np.ones((1, 3, 5, 4), dtype=np.float32)
    assert_np_equal(out.numpy(), expected)

    # Three inputs; the size-1 leading dimension of `a` expands to match `c`.
    c = wp.array(np.ones((2, 3, 5, 4), dtype=np.float32), device=device)
    out = wp.map(lambda a, b, c: a + b + c, a, b, c)
    assert isinstance(out, wp.array)
    test.assertEqual(out.shape, (2, 3, 5, 4))
    expected = np.ones((2, 3, 5, 4), dtype=np.float32) * 2.0
    assert_np_equal(out.numpy(), expected)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def test_input_validity(test, device):
    """wp.map raises descriptive errors for invalid functions and incompatible inputs.

    All arrays are allocated on ``device`` so the checks run on the device this
    test is registered for, not on whatever the default device happens to be.
    """

    @wp.func
    def empty_function(f: float):
        pass

    a1 = wp.empty(3, dtype=wp.float32, device=device)
    # A function without a return value cannot produce an output array.
    with test.assertRaisesRegex(
        TypeError,
        "The provided function must return a value$",
    ):
        wp.map(empty_function, a1)

    @wp.func
    def unary_function(f: float):
        return 2.0 * f

    # More inputs than the function accepts.
    with test.assertRaisesRegex(
        TypeError,
        r"Number of input arguments \(2\) does not match expected number of function arguments \(1\)$",
    ):
        wp.map(unary_function, a1, a1)

    @wp.func
    def int_function(i: int):
        return 5.0 * float(i)

    # Array dtype must match the annotated parameter type.
    with test.assertRaisesRegex(
        TypeError,
        'Incorrect input provided for argument "i": received array of dtype float32, expected int$',
    ):
        wp.map(int_function, a1)

    # Array shapes must be broadcast-compatible.
    i1 = wp.zeros((3, 2, 1), dtype=wp.float32, device=device)
    i2 = wp.ones((3, 3, 2), dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        ValueError,
        r"Shapes \(3, 2, 1\) and \(3, 3, 2\) are not broadcastable$",
    ):
        wp.map(wp.add, i1, i2)

    # Scalars mixed in do not rescue incompatible array shapes.
    xs = wp.zeros(3, dtype=wp.float32, device=device)
    ys = wp.zeros(5, dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        ValueError,
        r"Shapes \(3,\) and \(5,\) are not broadcastable$",
    ):
        wp.map(lambda a, b, c: a + b * c, 0.0, xs, ys)

    # At least one argument must be a warp array (a Python list does not count).
    with test.assertRaisesRegex(
        ValueError,
        "map requires at least one warp.array input$",
    ):
        wp.map(lambda a, b, c: a * b * c, 2.0, 0.4, [5.0])
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def test_output_validity(test, device):
    """wp.map validates the number, shape, and dtype of user-provided output arrays.

    All arrays are allocated on ``device`` so the checks run on the device this
    test is registered for, not on whatever the default device happens to be.
    """
    xs = wp.zeros(3, dtype=wp.float32, device=device)
    ys = wp.ones(3, dtype=wp.float32, device=device)

    # Output length differs from the broadcast input shape.
    out = wp.empty(2, dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        r"Output array shape \(2,\) does not match expected shape \(3,\)$",
    ):
        wp.map(wp.sub, xs, ys, out=out)

    # A single array is given where the function returns two values.
    out = wp.empty((2, 3), dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        r"Invalid output provided, expected 2 Warp arrays with shape \(3,\) and dtypes \(float32, float32\)$",
    ):
        wp.map(lambda x, y: (x, y), xs, ys, out=out)

    # Output dtype differs from the function's return type.
    out = wp.empty(3, dtype=wp.int32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        "Output array dtype int32 does not match expected dtype float32$",
    ):
        wp.map(lambda x, y: x - y, xs, ys, out=out)

    # Output has an extra dimension.
    out = wp.empty((3, 1), dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        r"Output array shape \(3, 1\) does not match expected shape \(3,\)$",
    ):
        wp.map(wp.mul, xs, ys, out=out)

    # Second of two outputs has the wrong dtype.
    out1 = wp.empty(3, dtype=wp.float32, device=device)
    out2 = wp.empty(3, dtype=wp.int32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        "Output array 1 dtype int32 does not match expected dtype float32$",
    ):
        wp.map(lambda x, y: (x, y), xs, ys, out=[out1, out2])

    # More output arrays than returned values.
    out1 = wp.empty(3, dtype=wp.float32, device=device)
    out2 = wp.empty(3, dtype=wp.float32, device=device)
    out3 = wp.empty(1, dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        r"Number of provided output arrays \(3\) does not match expected number of function outputs \(2\)$",
    ):
        wp.map(lambda x, y: (x, y), xs, ys, out=[out1, out2, out3])

    # Second of two outputs has the wrong shape.
    out1 = wp.empty(3, dtype=wp.float32, device=device)
    out2 = wp.empty((3, 1), dtype=wp.float32, device=device)
    with test.assertRaisesRegex(
        TypeError,
        r"Output array 1 shape \(3, 1\) does not match expected shape \(3,\)$",
    ):
        wp.map(lambda x, y: (x, y), xs, ys, out=[out1, out2])
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def test_kernel_creation(test, device):
    """``return_kernel=True`` returns a reusable wp.Kernel instead of launching it.

    The kernel is then launched explicitly on a differently sized input. Both
    the new arrays and the launch target ``device``, so the test exercises the
    device it is registered for rather than the default device.
    """
    a = wp.array(np.arange(10, dtype=np.float32), device=device)
    kernel = wp.map(lambda a: a + 2.0, a, return_kernel=True)
    test.assertIsInstance(kernel, wp.Kernel)

    b = wp.zeros(20, device=device)
    out = wp.empty_like(b)
    wp.launch(kernel, dim=len(b), inputs=[b], outputs=[out], device=device)
    expected = np.full(20, 2.0, dtype=np.float32)
    assert_np_equal(out.numpy(), expected)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def test_graph_capture(test, device):
    """wp.map calls and in-place array ops can be recorded into a CUDA graph and replayed."""
    driver = warp.context.runtime.driver_version
    assert driver is not None
    if driver < (12, 3):
        test.skipTest("Module loading during CUDA graph capture is not supported on driver versions < 12.3")

    a_np = np.arange(10, dtype=np.float32)
    b_np = np.arange(1, 11, dtype=np.float32)
    a = wp.array(a_np, device=device)
    b = wp.array(b_np, device=device)

    # Nothing runs during capture; the work is executed by capture_launch below.
    with wp.ScopedCapture(device, force_module_load=False) as capture:
        out = wp.map(lambda x, y: wp.abs(2.0 * x - y), a, b)
        out = wp.map(wp.sin, out)
        assert isinstance(out, wp.array)
        out *= 2.0

    reference = np.array(2.0 * np.sin(np.abs(a_np * 2.0 - b_np)), dtype=np.float32)
    assert capture.graph is not None
    wp.capture_launch(capture.graph)
    assert_np_equal(out.numpy(), reference, tol=1e-6)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def test_renamed_warp_module(test, device):
    """Builtins remain resolvable when the warp module is imported under an unusual alias."""
    import warp as uncommon_name

    @wp.func
    def my_func(a: float):
        return uncommon_name.abs(2.0 * a - 10.0)

    xs = wp.array(np.arange(10, dtype=np.float32), device=device)
    ys = wp.array(np.arange(1, 11, dtype=np.float32), device=device)
    xs_np = xs.numpy()
    ys_np = ys.numpy()

    # Alias referenced from inside a lambda.
    via_lambda = wp.map(lambda x, y: uncommon_name.abs(2.0 * x - y), xs, ys)
    assert isinstance(via_lambda, wp.array)
    assert_np_equal(via_lambda.numpy(), np.array(np.abs(xs_np * 2.0 - ys_np), dtype=np.float32), tol=1e-6)

    # Alias referenced from inside a wp.func.
    via_func = wp.map(my_func, xs)
    assert isinstance(via_func, wp.array)
    assert_np_equal(via_func.numpy(), np.array(np.abs(xs_np * 2.0 - 10.0), dtype=np.float32), tol=1e-6)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
devices = get_test_devices("basic")
cuda_test_devices = get_cuda_test_devices()


class TestMap(unittest.TestCase):
    """wp.map tests; per-device test methods are attached by add_function_test below."""


# (test function, devices it runs on), registered in this order.
for _func, _devices in (
    (test_mixed_inputs, devices),
    (test_lambda, devices),
    (test_multiple_return_values, devices),
    (test_custom_struct_operator, devices),
    (test_name_clash, devices),
    (test_gradient, devices),
    (test_array_ops, devices),
    (test_indexedarrays, devices),
    (test_broadcasting, devices),
    (test_input_validity, devices),
    (test_output_validity, devices),
    (test_kernel_creation, devices),
    (test_graph_capture, cuda_test_devices),
    (test_renamed_warp_module, devices),
):
    add_function_test(TestMap, _func.__name__, _func, devices=_devices)
del _func, _devices


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|
warp/tests/test_mat.py
CHANGED
|
@@ -114,23 +114,6 @@ def test_anon_constructor_error_invalid_arg_count(test, device):
|
|
|
114
114
|
wp.launch(kernel, dim=1, inputs=[], device=device)
|
|
115
115
|
|
|
116
116
|
|
|
117
|
-
def test_anon_xform_constructor_error_type_mismatch(test, device):
|
|
118
|
-
@wp.kernel
|
|
119
|
-
def kernel():
|
|
120
|
-
wp.matrix(wp.vec3(1.0, 2.0, 3.0), wp.quat(0.0, 0.0, 0.0, 1.0), wp.vec3(2.0, 2.0, 2.0), wp.float64)
|
|
121
|
-
|
|
122
|
-
with test.assertRaisesRegex(
|
|
123
|
-
RuntimeError,
|
|
124
|
-
r"all values used to initialize this transformation matrix are expected to be of the type `float64`$",
|
|
125
|
-
):
|
|
126
|
-
wp.launch(
|
|
127
|
-
kernel,
|
|
128
|
-
dim=1,
|
|
129
|
-
inputs=[],
|
|
130
|
-
device=device,
|
|
131
|
-
)
|
|
132
|
-
|
|
133
|
-
|
|
134
117
|
def test_tpl_constructor_error_incompatible_sizes(test, device):
|
|
135
118
|
@wp.kernel
|
|
136
119
|
def kernel():
|
|
@@ -220,7 +203,7 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
|
|
|
220
203
|
outcomponents: wp.array(dtype=wptype),
|
|
221
204
|
outcomponents_alt: wp.array(dtype=wptype),
|
|
222
205
|
):
|
|
223
|
-
m =
|
|
206
|
+
m = wp.transform_compose(p[0], r[0], s[0])
|
|
224
207
|
|
|
225
208
|
R = wp.transpose(wp.quat_to_matrix(r[0]))
|
|
226
209
|
c0 = s[0][0] * R[0]
|
|
@@ -2242,6 +2225,27 @@ def test_mat_array_sub_inplace(test, device):
|
|
|
2242
2225
|
assert_np_equal(x.grad.numpy(), np.array([[[-1.0, -1.0], [-1.0, -1.0]]], dtype=float))
|
|
2243
2226
|
|
|
2244
2227
|
|
|
2228
|
+
@wp.kernel
def scalar_mat_div(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    # Scalar-divided-by-matrix: test_scalar_mat_div expects this to act component-wise.
    tid = wp.tid()
    y[tid] = 1.0 / x[tid]
|
|
2232
|
+
|
|
2233
|
+
|
|
2234
|
+
def test_scalar_mat_div(test, device):
    """Forward values and adjoints of the scalar-by-matrix division kernel."""
    x = wp.array((wp.mat22(1.0, 2.0, 4.0, 8.0),), dtype=wp.mat22, requires_grad=True, device=device)
    y = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(scalar_mat_div, dim=1, inputs=(x,), outputs=(y,), device=device)

    # Seed the output adjoint with ones and propagate back to x.
    y.grad = wp.ones_like(y)
    tape.backward()

    # Forward: component-wise reciprocal of x.
    assert_np_equal(y.numpy(), np.array((((1.0, 0.5), (0.25, 0.125)),), dtype=float))
    # Backward: d(1/x)/dx = -1/x**2 per component.
    assert_np_equal(x.grad.numpy(), np.array((((-1.0, -0.25), (-0.0625, -0.015625)),), dtype=float))
|
|
2247
|
+
|
|
2248
|
+
|
|
2245
2249
|
devices = get_test_devices()
|
|
2246
2250
|
|
|
2247
2251
|
|
|
@@ -2323,12 +2327,6 @@ add_function_test(
|
|
|
2323
2327
|
test_anon_constructor_error_invalid_arg_count,
|
|
2324
2328
|
devices=devices,
|
|
2325
2329
|
)
|
|
2326
|
-
add_function_test(
|
|
2327
|
-
TestMat,
|
|
2328
|
-
"test_anon_xform_constructor_error_type_mismatch",
|
|
2329
|
-
test_anon_xform_constructor_error_type_mismatch,
|
|
2330
|
-
devices=devices,
|
|
2331
|
-
)
|
|
2332
2330
|
add_function_test(
|
|
2333
2331
|
TestMat,
|
|
2334
2332
|
"test_tpl_constructor_error_incompatible_sizes",
|
|
@@ -2379,6 +2377,7 @@ add_function_test(TestMat, "test_mat_add_inplace", test_mat_add_inplace, devices
|
|
|
2379
2377
|
add_function_test(TestMat, "test_mat_sub_inplace", test_mat_sub_inplace, devices=devices)
|
|
2380
2378
|
add_function_test(TestMat, "test_mat_array_add_inplace", test_mat_array_add_inplace, devices=devices)
|
|
2381
2379
|
add_function_test(TestMat, "test_mat_array_sub_inplace", test_mat_array_sub_inplace, devices=devices)
|
|
2380
|
+
add_function_test(TestMat, "test_scalar_mat_div", test_scalar_mat_div, devices=devices)
|
|
2382
2381
|
|
|
2383
2382
|
|
|
2384
2383
|
if __name__ == "__main__":
|
warp/tests/test_quat.py
CHANGED
|
@@ -2047,7 +2047,6 @@ def quat_extract_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float))
|
|
|
2047
2047
|
y[tid] = b
|
|
2048
2048
|
|
|
2049
2049
|
|
|
2050
|
-
""" TODO: rhs attribute indexing
|
|
2051
2050
|
@wp.kernel
|
|
2052
2051
|
def quat_extract_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float)):
|
|
2053
2052
|
tid = wp.tid()
|
|
@@ -2055,7 +2054,6 @@ def quat_extract_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=float))
|
|
|
2055
2054
|
a = x[tid]
|
|
2056
2055
|
b = a.x + float(2.0) * a.y + 3.0 * a.z + 4.0 * a.w
|
|
2057
2056
|
y[tid] = b
|
|
2058
|
-
"""
|
|
2059
2057
|
|
|
2060
2058
|
|
|
2061
2059
|
def test_quat_extract(test, device):
|
|
@@ -2074,7 +2072,7 @@ def test_quat_extract(test, device):
|
|
|
2074
2072
|
assert_np_equal(x.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
|
|
2075
2073
|
|
|
2076
2074
|
run(quat_extract_subscript)
|
|
2077
|
-
|
|
2075
|
+
run(quat_extract_attribute)
|
|
2078
2076
|
|
|
2079
2077
|
|
|
2080
2078
|
@wp.kernel
|
|
@@ -2163,7 +2161,6 @@ def quat_array_extract_subscript(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dty
|
|
|
2163
2161
|
y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
|
|
2164
2162
|
|
|
2165
2163
|
|
|
2166
|
-
""" TODO: rhs attribute indexing
|
|
2167
2164
|
@wp.kernel
|
|
2168
2165
|
def quat_array_extract_attribute(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
|
|
2169
2166
|
i, j = wp.tid()
|
|
@@ -2172,7 +2169,6 @@ def quat_array_extract_attribute(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dty
|
|
|
2172
2169
|
c = x[i, j].z
|
|
2173
2170
|
d = x[i, j].w
|
|
2174
2171
|
y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
|
|
2175
|
-
"""
|
|
2176
2172
|
|
|
2177
2173
|
|
|
2178
2174
|
def test_quat_array_extract(test, device):
|
|
@@ -2191,7 +2187,7 @@ def test_quat_array_extract(test, device):
|
|
|
2191
2187
|
assert_np_equal(x.grad.numpy(), np.array([[[1.0, 2.0, 3.0, 4.0]]], dtype=float))
|
|
2192
2188
|
|
|
2193
2189
|
run(quat_array_extract_subscript)
|
|
2194
|
-
|
|
2190
|
+
run(quat_array_extract_attribute)
|
|
2195
2191
|
|
|
2196
2192
|
|
|
2197
2193
|
@wp.kernel
|
|
@@ -2249,7 +2245,6 @@ def quat_add_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.
|
|
|
2249
2245
|
y[i] = a
|
|
2250
2246
|
|
|
2251
2247
|
|
|
2252
|
-
""" TODO: rhs attribute indexing
|
|
2253
2248
|
@wp.kernel
|
|
2254
2249
|
def quat_add_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
|
|
2255
2250
|
i = wp.tid()
|
|
@@ -2263,7 +2258,6 @@ def quat_add_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.
|
|
|
2263
2258
|
a.w += 4.0 * b.w
|
|
2264
2259
|
|
|
2265
2260
|
y[i] = a
|
|
2266
|
-
"""
|
|
2267
2261
|
|
|
2268
2262
|
|
|
2269
2263
|
def test_quat_add_inplace(test, device):
|
|
@@ -2282,7 +2276,7 @@ def test_quat_add_inplace(test, device):
|
|
|
2282
2276
|
assert_np_equal(x.grad.numpy(), np.array([[1.0, 2.0, 3.0, 4.0]], dtype=float))
|
|
2283
2277
|
|
|
2284
2278
|
run(quat_add_inplace_subscript)
|
|
2285
|
-
|
|
2279
|
+
run(quat_add_inplace_attribute)
|
|
2286
2280
|
|
|
2287
2281
|
|
|
2288
2282
|
@wp.kernel
|
|
@@ -2300,7 +2294,6 @@ def quat_sub_inplace_subscript(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.
|
|
|
2300
2294
|
y[i] = a
|
|
2301
2295
|
|
|
2302
2296
|
|
|
2303
|
-
""" TODO: rhs attribute indexing
|
|
2304
2297
|
@wp.kernel
|
|
2305
2298
|
def quat_sub_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
|
|
2306
2299
|
i = wp.tid()
|
|
@@ -2314,7 +2307,6 @@ def quat_sub_inplace_attribute(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.
|
|
|
2314
2307
|
a.w -= 4.0 * b.w
|
|
2315
2308
|
|
|
2316
2309
|
y[i] = a
|
|
2317
|
-
"""
|
|
2318
2310
|
|
|
2319
2311
|
|
|
2320
2312
|
def test_quat_sub_inplace(test, device):
|
|
@@ -2333,7 +2325,7 @@ def test_quat_sub_inplace(test, device):
|
|
|
2333
2325
|
assert_np_equal(x.grad.numpy(), np.array([[-1.0, -2.0, -3.0, -4.0]], dtype=float))
|
|
2334
2326
|
|
|
2335
2327
|
run(quat_sub_inplace_subscript)
|
|
2336
|
-
|
|
2328
|
+
run(quat_sub_inplace_attribute)
|
|
2337
2329
|
|
|
2338
2330
|
|
|
2339
2331
|
@wp.kernel
|
|
@@ -2358,7 +2350,6 @@ def test_quat_array_add_inplace(test, device):
|
|
|
2358
2350
|
assert_np_equal(x.grad.numpy(), np.array([[1.0, 1.0, 1.0, 1.0]], dtype=float))
|
|
2359
2351
|
|
|
2360
2352
|
|
|
2361
|
-
""" TODO: quat negation operator
|
|
2362
2353
|
@wp.kernel
|
|
2363
2354
|
def quat_array_sub_inplace(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
|
|
2364
2355
|
i = wp.tid()
|
|
@@ -2379,7 +2370,28 @@ def test_quat_array_sub_inplace(test, device):
|
|
|
2379
2370
|
|
|
2380
2371
|
assert_np_equal(y.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
|
|
2381
2372
|
assert_np_equal(x.grad.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
|
|
2382
|
-
|
|
2373
|
+
|
|
2374
|
+
|
|
2375
|
+
@wp.kernel
|
|
2376
|
+
def scalar_quat_div(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
|
|
2377
|
+
i = wp.tid()
|
|
2378
|
+
y[i] = 1.0 / x[i]
|
|
2379
|
+
|
|
2380
|
+
|
|
2381
|
+
def test_scalar_quat_div(test, device):
|
|
2382
|
+
x = wp.array((wp.quat(1.0, 2.0, 4.0, 8.0),), dtype=wp.quat, requires_grad=True, device=device)
|
|
2383
|
+
y = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
|
|
2384
|
+
|
|
2385
|
+
tape = wp.Tape()
|
|
2386
|
+
with tape:
|
|
2387
|
+
wp.launch(scalar_quat_div, 1, inputs=(x,), outputs=(y,), device=device)
|
|
2388
|
+
|
|
2389
|
+
y.grad = wp.ones_like(y)
|
|
2390
|
+
tape.backward()
|
|
2391
|
+
|
|
2392
|
+
assert_np_equal(y.numpy(), np.array(((1.0, 0.5, 0.25, 0.125),), dtype=float))
|
|
2393
|
+
assert_np_equal(x.grad.numpy(), np.array(((-1.0, -0.25, -0.0625, -0.015625),), dtype=float))
|
|
2394
|
+
|
|
2383
2395
|
|
|
2384
2396
|
devices = get_test_devices()
|
|
2385
2397
|
|
|
@@ -2491,7 +2503,8 @@ add_function_test(TestQuat, "test_quat_array_assign", test_quat_array_assign, de
|
|
|
2491
2503
|
add_function_test(TestQuat, "test_quat_add_inplace", test_quat_add_inplace, devices=devices)
|
|
2492
2504
|
add_function_test(TestQuat, "test_quat_sub_inplace", test_quat_sub_inplace, devices=devices)
|
|
2493
2505
|
add_function_test(TestQuat, "test_quat_array_add_inplace", test_quat_array_add_inplace, devices=devices)
|
|
2494
|
-
|
|
2506
|
+
add_function_test(TestQuat, "test_quat_array_sub_inplace", test_quat_array_sub_inplace, devices=devices)
|
|
2507
|
+
add_function_test(TestQuat, "test_scalar_quat_div", test_scalar_quat_div, devices=devices)
|
|
2495
2508
|
|
|
2496
2509
|
|
|
2497
2510
|
if __name__ == "__main__":
|