warp-lang 1.0.0b5-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- docs/conf.py +3 -4
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/example_dem.py +28 -26
- examples/example_diffray.py +37 -30
- examples/example_fluid.py +7 -3
- examples/example_jacobian_ik.py +1 -1
- examples/example_mesh_intersect.py +10 -7
- examples/example_nvdb.py +3 -3
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +9 -5
- examples/example_sim_cloth.py +29 -25
- examples/example_sim_fk_grad.py +2 -2
- examples/example_sim_fk_grad_torch.py +3 -3
- examples/example_sim_grad_bounce.py +11 -8
- examples/example_sim_grad_cloth.py +12 -9
- examples/example_sim_granular.py +2 -2
- examples/example_sim_granular_collision_sdf.py +13 -13
- examples/example_sim_neo_hookean.py +3 -3
- examples/example_sim_particle_chain.py +2 -2
- examples/example_sim_quadruped.py +8 -5
- examples/example_sim_rigid_chain.py +8 -5
- examples/example_sim_rigid_contact.py +13 -10
- examples/example_sim_rigid_fem.py +2 -2
- examples/example_sim_rigid_gyroscopic.py +2 -2
- examples/example_sim_rigid_kinematics.py +1 -1
- examples/example_sim_trajopt.py +3 -2
- examples/fem/example_apic_fluid.py +5 -7
- examples/fem/example_diffusion_mgpu.py +18 -16
- warp/__init__.py +3 -2
- warp/bin/warp.so +0 -0
- warp/build_dll.py +29 -9
- warp/builtins.py +206 -7
- warp/codegen.py +58 -38
- warp/config.py +3 -1
- warp/context.py +234 -128
- warp/fem/__init__.py +2 -2
- warp/fem/cache.py +2 -1
- warp/fem/field/nodal_field.py +18 -17
- warp/fem/geometry/hexmesh.py +11 -6
- warp/fem/geometry/quadmesh_2d.py +16 -12
- warp/fem/geometry/tetmesh.py +19 -8
- warp/fem/geometry/trimesh_2d.py +18 -7
- warp/fem/integrate.py +341 -196
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +138 -53
- warp/fem/quadrature/quadrature.py +81 -9
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_space.py +169 -51
- warp/fem/space/grid_2d_function_space.py +2 -2
- warp/fem/space/grid_3d_function_space.py +2 -2
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +9 -6
- warp/fem/space/quadmesh_2d_function_space.py +2 -2
- warp/fem/space/shape/cube_shape_function.py +27 -15
- warp/fem/space/shape/square_shape_function.py +29 -18
- warp/fem/space/tetmesh_function_space.py +2 -2
- warp/fem/space/topology.py +10 -0
- warp/fem/space/trimesh_2d_function_space.py +2 -2
- warp/fem/utils.py +10 -5
- warp/native/array.h +49 -8
- warp/native/builtin.h +31 -14
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1177 -1108
- warp/native/intersect.h +4 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +65 -6
- warp/native/mesh.h +126 -5
- warp/native/quat.h +28 -4
- warp/native/vec.h +76 -14
- warp/native/warp.cu +1 -6
- warp/render/render_opengl.py +261 -109
- warp/sim/import_mjcf.py +13 -7
- warp/sim/import_urdf.py +14 -14
- warp/sim/inertia.py +17 -18
- warp/sim/model.py +67 -67
- warp/sim/render.py +1 -1
- warp/sparse.py +6 -6
- warp/stubs.py +19 -81
- warp/tape.py +1 -1
- warp/tests/__main__.py +3 -6
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +102 -106
- warp/tests/test_arithmetic.py +39 -40
- warp/tests/test_array.py +46 -48
- warp/tests/test_array_reduce.py +25 -19
- warp/tests/test_atomic.py +62 -26
- warp/tests/test_bool.py +16 -11
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +9 -12
- warp/tests/test_closest_point_edge_edge.py +53 -57
- warp/tests/test_codegen.py +164 -134
- warp/tests/test_compile_consts.py +13 -19
- warp/tests/test_conditional.py +30 -32
- warp/tests/test_copy.py +9 -12
- warp/tests/test_ctypes.py +90 -98
- warp/tests/test_dense.py +20 -14
- warp/tests/test_devices.py +34 -35
- warp/tests/test_dlpack.py +74 -75
- warp/tests/test_examples.py +215 -97
- warp/tests/test_fabricarray.py +15 -21
- warp/tests/test_fast_math.py +14 -11
- warp/tests/test_fem.py +280 -97
- warp/tests/test_fp16.py +19 -15
- warp/tests/test_func.py +177 -194
- warp/tests/test_generics.py +71 -77
- warp/tests/test_grad.py +83 -32
- warp/tests/test_grad_customs.py +7 -9
- warp/tests/test_hash_grid.py +6 -10
- warp/tests/test_import.py +9 -23
- warp/tests/test_indexedarray.py +19 -21
- warp/tests/test_intersect.py +15 -9
- warp/tests/test_large.py +17 -19
- warp/tests/test_launch.py +14 -17
- warp/tests/test_lerp.py +63 -63
- warp/tests/test_lvalue.py +84 -35
- warp/tests/test_marching_cubes.py +9 -13
- warp/tests/test_mat.py +388 -3004
- warp/tests/test_mat_lite.py +9 -12
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +10 -11
- warp/tests/test_matmul.py +104 -100
- warp/tests/test_matmul_lite.py +72 -98
- warp/tests/test_mesh.py +35 -32
- warp/tests/test_mesh_query_aabb.py +18 -25
- warp/tests/test_mesh_query_point.py +39 -23
- warp/tests/test_mesh_query_ray.py +9 -21
- warp/tests/test_mlp.py +8 -9
- warp/tests/test_model.py +89 -93
- warp/tests/test_modules_lite.py +15 -25
- warp/tests/test_multigpu.py +87 -114
- warp/tests/test_noise.py +10 -12
- warp/tests/test_operators.py +14 -21
- warp/tests/test_options.py +10 -11
- warp/tests/test_pinned.py +16 -18
- warp/tests/test_print.py +16 -20
- warp/tests/test_quat.py +121 -88
- warp/tests/test_rand.py +12 -13
- warp/tests/test_reload.py +27 -32
- warp/tests/test_rounding.py +7 -10
- warp/tests/test_runlength_encode.py +105 -106
- warp/tests/test_smoothstep.py +8 -9
- warp/tests/test_snippet.py +13 -22
- warp/tests/test_sparse.py +30 -29
- warp/tests/test_spatial.py +179 -174
- warp/tests/test_streams.py +100 -107
- warp/tests/test_struct.py +98 -67
- warp/tests/test_tape.py +11 -17
- warp/tests/test_torch.py +89 -86
- warp/tests/test_transient_module.py +9 -12
- warp/tests/test_types.py +328 -50
- warp/tests/test_utils.py +217 -218
- warp/tests/test_vec.py +133 -2133
- warp/tests/test_vec_lite.py +8 -11
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +391 -382
- warp/tests/test_volume_write.py +122 -135
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/{test_base.py → unittest_utils.py} +138 -25
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
- warp/thirdparty/unittest_parallel.py +257 -54
- warp/types.py +119 -98
- warp/utils.py +14 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -239
- warp/tests/test_conditional_unequal_types_kernels.py +0 -14
- warp/tests/test_coverage.py +0 -38
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_math.py
CHANGED
@@ -5,13 +5,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

-from typing import NamedTuple
 import unittest
+from typing import NamedTuple

 import numpy as np

 import warp as wp
-from warp.tests.
+from warp.tests.unittest_utils import *

 wp.init()

@@ -176,19 +176,18 @@ def test_mat_type(test, device):
         raise ValueError("mat to string error")


-
-
+devices = get_test_devices()
+
+
+class TestMath(unittest.TestCase):
+    pass

-    class TestMath(parent):
-        pass

-
-
-
-    return TestMath
+add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
+add_function_test(TestMath, "test_vec_type", test_vec_type, devices=devices)
+add_function_test(TestMath, "test_mat_type", test_mat_type, devices=devices)


 if __name__ == "__main__":
     wp.build.clear_kernel_cache()
-    _ = register(unittest.TestCase)
     unittest.main(verbosity=2)
warp/tests/test_matmul.py
CHANGED
@@ -1,11 +1,21 @@
-
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
 import unittest

+import numpy as np
+
 import warp as wp
-from warp.tests.
+from warp.tests.unittest_utils import *

 wp.init()

+from warp.context import runtime  # noqa: E402
+

 class gemm_test_bed_runner:
     def __init__(self, dtype, device):
@@ -21,63 +31,54 @@ class gemm_test_bed_runner:
                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
-            D = wp.array2d(
-                np.zeros((m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True)
+            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
         else:
             A = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
-            )
-            D = wp.array3d(
-                np.zeros((batch_count, m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
+            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
         return A, B, C, D

     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
         A, B, C, D = self.alloc(m, n, k, batch_count)
         ones = wp.zeros_like(D)
         ones.fill_(1.0)
-
+
         if batch_count == 1:
             tape = wp.Tape()
             with tape:
                 wp.matmul(A, B, C, D, alpha, beta, False, self.device)
-            tape.backward(grads={D
-
+            tape.backward(grads={D: ones})
+
             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
             assert np.array_equal(D_np, D.numpy())

@@ -89,8 +90,8 @@ class gemm_test_bed_runner:
             tape = wp.Tape()
             with tape:
                 wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
-            tape.backward(grads={D
-
+            tape.backward(grads={D: ones})
+
             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
             assert np.array_equal(D_np, D.numpy())

@@ -132,75 +133,45 @@ class gemm_test_bed_runner_transpose:
                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
-            )
-            D = wp.array2d(
-                np.zeros((m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            AT = wp.array2d(
-                A.numpy().transpose([1, 0]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            BT = wp.array2d(
-                B.numpy().transpose([1, 0]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
+            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+            AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+            BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
         else:
             A = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
-            )
-            D = wp.array3d(
-                np.zeros((batch_count, m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            AT = wp.array3d(
-                A.numpy().transpose([0, 2, 1]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            BT = wp.array3d(
-                B.numpy().transpose([0, 2, 1]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
+            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+            AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+            BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
         return A, B, C, D, AT, BT

     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
@@ -219,17 +190,17 @@ class gemm_test_bed_runner_transpose:
         ones3.fill_(1.0)

         if batch_count == 1:
-            ATT1 = AT1.transpose([1, 0])
+            ATT1 = AT1.transpose([1, 0])
             BTT1 = BT1.transpose([1, 0])
-            ATT2 = AT2.transpose([1, 0])
+            ATT2 = AT2.transpose([1, 0])
             BTT2 = BT2.transpose([1, 0])
             tape = wp.Tape()
             with tape:
                 wp.matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
                 wp.matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
                 wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
-            tape.backward(grads={D1
-
+            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
             assert np.array_equal(D_np, D1.numpy())
             assert np.array_equal(D_np, D2.numpy())
@@ -240,7 +211,7 @@ class gemm_test_bed_runner_transpose:
             adj_C_np = beta * ones1.numpy()

         else:
-            ATT1 = AT1.transpose([0, 2, 1])
+            ATT1 = AT1.transpose([0, 2, 1])
             BTT1 = BT1.transpose([0, 2, 1])
             ATT2 = AT2.transpose([0, 2, 1])
             BTT2 = BT2.transpose([0, 2, 1])
@@ -249,8 +220,8 @@ class gemm_test_bed_runner_transpose:
                 wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
                 wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
                 wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
-            tape.backward(grads={D1
-
+            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
             assert np.array_equal(D_np, D1.numpy())
             assert np.array_equal(D_np, D2.numpy())
@@ -288,11 +259,13 @@ def test_f16(test, device):
     gemm_test_bed_runner_transpose(wp.float16, device).run()


+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_f32(test, device):
     gemm_test_bed_runner(wp.float32, device).run()
     gemm_test_bed_runner_transpose(wp.float32, device).run()


+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_f64(test, device):
     gemm_test_bed_runner(wp.float64, device).run()
     gemm_test_bed_runner_transpose(wp.float64, device).run()
@@ -304,6 +277,7 @@ def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float))
     wp.atomic_add(loss, 0, arr[i, j])


+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_tape(test, device):
     rng = np.random.default_rng(42)
     low = -4.5
@@ -331,6 +305,7 @@ def test_tape(test, device):

     tape.backward(loss=loss)
     A_grad = A.grad.numpy()
+    tape.reset()

     # test adjoint
     D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
@@ -342,6 +317,7 @@ def test_tape(test, device):
     assert_array_equal(A.grad, wp.zeros_like(A))


+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_operator(test, device):
     rng = np.random.default_rng(42)
     low = -4.5
@@ -377,6 +353,7 @@ def test_operator(test, device):
     assert_array_equal(A.grad, wp.zeros_like(A))


+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_large_batch_count(test, device):
     rng = np.random.default_rng(42)
     low = -4.5
@@ -386,31 +363,38 @@ def test_large_batch_count(test, device):
     k = 4
     batch_count = 65535 * 2 + int(65535 / 2)
     A = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
     )
     B = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
     )
     C = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-
-
-
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
     )
+    D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
     ones = wp.zeros_like(D)
     ones.fill_(1.0)

     alpha = 1.0
     beta = 1.0
-
+
     tape = wp.Tape()
     with tape:
         wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False, device=device)
-    tape.backward(grads={D
+    tape.backward(grads={D: ones})

     D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
     assert np.array_equal(D_np, D.numpy())
-
+
     adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
     adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
     adj_C_np = beta * ones.numpy()
@@ -420,30 +404,50 @@ def test_large_batch_count(test, device):
     assert np.array_equal(adj_C_np, C.grad.numpy())


-def
-
+def test_adjoint_accumulation(test, device):
+    a_np = np.ones(shape=(2,3))
+    b_np = np.ones(shape=(3,2))
+    c_np = np.zeros(shape=(2,2))
+    d_np = np.zeros(shape=(2,2))

-
-
+    a_wp = wp.from_numpy(a_np, dtype=float, requires_grad=True)
+    b_wp = wp.from_numpy(b_np, dtype=float, requires_grad=True)
+    c_wp = wp.from_numpy(c_np, dtype=float, requires_grad=True)
+    d1_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True)
+    d2_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True)

-
-
-
+    tape = wp.Tape()
+
+    with tape:
+        wp.matmul(a_wp, b_wp, c_wp, d1_wp, alpha=1.0, beta=1.0)
+        wp.matmul(a_wp, b_wp, d1_wp, d2_wp, alpha=1.0, beta=1.0)
+
+    d_grad = wp.zeros_like(d2_wp)
+    d_grad.fill_(1.)
+    grads = {d2_wp : d_grad}
+    tape.backward(grads=grads)
+
+    assert np.array_equal(a_wp.grad.numpy(), 4.0 * np.ones(shape=(2,3)))
+    assert np.array_equal(b_wp.grad.numpy(), 4.0 * np.ones(shape=(3,2)))
+    assert np.array_equal(c_wp.grad.numpy(), np.ones(shape=(2,2)))
+
+
+devices = get_test_devices()
+
+
+class TestMatmul(unittest.TestCase):
+    pass

-if runtime.core.is_cutlass_enabled():
-    # add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
-    add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
-    add_function_test(TestMatmul, "test_f64", test_f64, devices=devices)
-    add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
-    add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
-    add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices)
-else:
-    print("Skipping matmul tests because CUTLASS is not supported in this build")

-
+# add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
+add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
+add_function_test(TestMatmul, "test_f64", test_f64, devices=devices)
+add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
+add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
+add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices)
+add_function_test(TestMatmul, "test_adjoint_accumulation", test_adjoint_accumulation, devices=devices)


 if __name__ == "__main__":
     wp.build.clear_kernel_cache()
-    _ = register(unittest.TestCase)
     unittest.main(verbosity=2, failfast=False)
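The gradient values asserted by the new test_adjoint_accumulation test above can be checked by hand: with beta=1.0, d1 = a @ b + c and d2 = a @ b + d1, so the adjoint of a (and of b) receives one term from the second matmul and a second term routed through d1, which works out to 4.0 everywhere for the all-ones inputs, while c only feeds d1 once. A small NumPy sketch of that arithmetic (independent of Warp, just reproducing the expected values):

import numpy as np

a = np.ones((2, 3))
b = np.ones((3, 2))
ones = np.ones((2, 2))  # upstream gradient seeded on d2

# adjoint of a: one term from d2 = a @ b + d1 directly, one routed through
# d1 = a @ b + c (whose adjoint is `ones` because beta = 1.0)
grad_a = ones @ b.T + ones @ b.T  # 2.0 + 2.0 -> 4.0 everywhere
grad_b = a.T @ ones + a.T @ ones  # 2.0 + 2.0 -> 4.0 everywhere
grad_c = ones                     # c only enters through d1 once

assert np.array_equal(grad_a, np.full((2, 3), 4.0))
assert np.array_equal(grad_b, np.full((3, 2), 4.0))
assert np.array_equal(grad_c, np.ones((2, 2)))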
|