warp-lang 1.6.2-py3-none-macosx_10_13_universal2.whl → 1.7.0-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/tests/test_matmul.py DELETED
@@ -1,511 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import itertools
-import unittest
-from typing import Any
-
-import numpy as np
-
-import warp as wp
-from warp.tests.unittest_utils import *
-
-wp.init()  # For wp.context.runtime.core.is_cutlass_enabled()
-
-
-class gemm_test_bed_runner:
-    def __init__(self, dtype, device):
-        self.dtype = dtype
-        self.device = device
-
-    def alloc(self, m, n, k, batch_count):
-        rng = np.random.default_rng(42)
-        low = -4.5
-        high = 3.5
-        if batch_count == 1:
-            A = wp.array2d(
-                np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            B = wp.array2d(
-                np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            C = wp.array2d(
-                np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-        else:
-            A = wp.array3d(
-                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            B = wp.array3d(
-                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            C = wp.array3d(
-                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-        return A, B, C, D
-
-    def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-        A, B, C, D = self.alloc(m, n, k, batch_count)
-        ones = wp.zeros_like(D)
-        ones.fill_(1.0)
-
-        np_dtype = wp.types.warp_type_to_np_dtype[self.dtype]
-
-        if batch_count == 1:
-            tape = wp.Tape()
-            with tape:
-                wp.matmul(A, B, C, D, alpha, beta, False)
-            tape.backward(grads={D: ones})
-
-            D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C.numpy()
-            assert_np_equal(D.numpy(), D_np)
-
-            adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose(), dtype=np_dtype)
-            adj_B_np = alpha * np.matmul(A.numpy().transpose(), ones.numpy(), dtype=np_dtype)
-            adj_C_np = beta * ones.numpy()
-
-        else:
-            tape = wp.Tape()
-            with tape:
-                wp.batched_matmul(A, B, C, D, alpha, beta, False)
-            tape.backward(grads={D: ones})
-
-            D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C.numpy()
-            assert_np_equal(D.numpy(), D_np)
-
-            adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)), dtype=np_dtype)
-            adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy(), dtype=np_dtype)
-            adj_C_np = beta * ones.numpy()
-
-        assert_np_equal(A.grad.numpy(), adj_A_np)
-        assert_np_equal(B.grad.numpy(), adj_B_np)
-        assert_np_equal(C.grad.numpy(), adj_C_np)
-
-    def run(self):
-        Ms = [16, 32, 64]
-        Ns = [16, 32, 64]
-        Ks = [16, 32, 64]
-        batch_counts = [1, 4]
-        betas = [0.0, 1.0]
-        alpha = 1.0
-
-        for batch_count, m, n, k, beta in itertools.product(batch_counts, Ms, Ns, Ks, betas):
-            self.run_and_verify(m, n, k, batch_count, alpha, beta)
-
-
-class gemm_test_bed_runner_transpose:
-    def __init__(self, dtype, device):
-        self.dtype = dtype
-        self.device = device
-
-    def alloc(self, m, n, k, batch_count):
-        rng = np.random.default_rng(42)
-        low = -4.5
-        high = 3.5
-        if batch_count == 1:
-            A = wp.array2d(
-                np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            B = wp.array2d(
-                np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            C = wp.array2d(
-                np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-            AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
-            BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
-        else:
-            A = wp.array3d(
-                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            B = wp.array3d(
-                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            C = wp.array3d(
-                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True,
-            )
-            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-            AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
-            BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
-        return A, B, C, D, AT, BT
-
-    def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-        A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
-        C2 = wp.clone(C1)
-        C3 = wp.clone(C1)
-        D2 = wp.clone(D1)
-        D3 = wp.clone(D1)
-        AT2 = wp.clone(AT1)
-        BT2 = wp.clone(BT1)
-        ones1 = wp.zeros_like(D1)
-        ones1.fill_(1.0)
-        ones2 = wp.zeros_like(D2)
-        ones2.fill_(1.0)
-        ones3 = wp.zeros_like(D3)
-        ones3.fill_(1.0)
-
-        np_dtype = wp.types.warp_type_to_np_dtype[self.dtype]
-
-        if batch_count == 1:
-            ATT1 = AT1.transpose([1, 0])
-            BTT1 = BT1.transpose([1, 0])
-            ATT2 = AT2.transpose([1, 0])
-            BTT2 = BT2.transpose([1, 0])
-            tape = wp.Tape()
-            with tape:
-                wp.matmul(A, BTT1, C1, D1, alpha, beta, False)
-                wp.matmul(ATT1, B, C2, D2, alpha, beta, False)
-                wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
-            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
-
-            D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C1.numpy()
-            assert_np_equal(D1.numpy(), D_np)
-            assert_np_equal(D2.numpy(), D_np)
-            assert_np_equal(D3.numpy(), D_np)
-
-            adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose(), dtype=np_dtype)
-            adj_B_np = alpha * np.matmul(A.numpy().transpose(), ones1.numpy(), dtype=np_dtype)
-            adj_C_np = beta * ones1.numpy()
-
-        else:
-            ATT1 = AT1.transpose([0, 2, 1])
-            BTT1 = BT1.transpose([0, 2, 1])
-            ATT2 = AT2.transpose([0, 2, 1])
-            BTT2 = BT2.transpose([0, 2, 1])
-            tape = wp.Tape()
-            with tape:
-                wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False)
-                wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False)
-                wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
-            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
-
-            D_np = alpha * np.matmul(A.numpy(), B.numpy(), dtype=np_dtype) + beta * C1.numpy()
-            assert_np_equal(D1.numpy(), D_np)
-            assert_np_equal(D2.numpy(), D_np)
-            assert_np_equal(D3.numpy(), D_np)
-
-            adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)), dtype=np_dtype)
-            adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy(), dtype=np_dtype)
-            adj_C_np = beta * ones1.numpy()
-
-        assert_np_equal(A.grad.numpy(), adj_A_np)
-        assert_np_equal(ATT1.grad.numpy(), adj_A_np)
-        assert_np_equal(ATT2.grad.numpy(), adj_A_np)
-        assert_np_equal(B.grad.numpy(), adj_B_np)
-        assert_np_equal(BTT1.grad.numpy(), adj_B_np)
-        assert_np_equal(BTT2.grad.numpy(), adj_B_np)
-        assert_np_equal(C1.grad.numpy(), adj_C_np)
-        assert_np_equal(C2.grad.numpy(), adj_C_np)
-        assert_np_equal(C3.grad.numpy(), adj_C_np)
-
-    def run(self):
-        m = 16
-        n = 32
-        k = 64
-        batch_counts = [1, 4]
-        beta = 1.0
-        alpha = 1.0
-
-        for batch_count in batch_counts:
-            self.run_and_verify(m, n, k, batch_count, alpha, beta)
-
-
-# NOTE: F16 tests are slow due to the performance of the reference numpy F16 matmuls performed on CPU.
-def test_f16(test, device):
-    gemm_test_bed_runner(wp.float16, device).run()
-    gemm_test_bed_runner_transpose(wp.float16, device).run()
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_f32(test, device):
-    gemm_test_bed_runner(wp.float32, device).run()
-    gemm_test_bed_runner_transpose(wp.float32, device).run()
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_f64(test, device):
-    gemm_test_bed_runner(wp.float64, device).run()
-    gemm_test_bed_runner_transpose(wp.float64, device).run()
-
-
-@wp.kernel
-def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
-    i, j = wp.tid()
-    wp.atomic_add(loss, 0, arr[i, j])
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_tape(test, device):
-    rng = np.random.default_rng(42)
-    low = -4.5
-    high = 3.5
-    m = 64
-    n = 128
-    k = 256
-    A = wp.array2d(
-        np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
-    )
-    B = wp.array2d(
-        np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
-    )
-    C = wp.array2d(
-        np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
-    )
-    D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
-    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
-
-    # test tape
-    tape = wp.Tape()
-    with tape:
-        wp.matmul(A, B, C, D)
-        wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
-
-    tape.backward(loss=loss)
-    A_grad = A.grad.numpy()
-    tape.reset()
-
-    # test adjoint
-    D.grad = wp.ones((m, n), dtype=float, device=device)
-    wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad)
-    assert_np_equal(A_grad, A.grad.numpy())
-
-    # test zero
-    tape.zero()
-    assert_array_equal(A.grad, wp.zeros_like(A))
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_operator(test, device):
-    rng = np.random.default_rng(42)
-    low = -4.5
-    high = 3.5
-    m = 64
-    n = 128
-    k = 256
-    A = wp.array2d(
-        np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
-    )
-    B = wp.array2d(
-        np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
-    )
-    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
-
-    # test tape
-    tape = wp.Tape()
-    with tape:
-        D = A @ B
-        wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
-
-    tape.backward(loss=loss)
-
-    # test adjoint
-    D.grad = wp.ones((m, n), dtype=float, device=device)
-    B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
-
-    adj_A = D.grad @ B_transpose
-    assert_array_equal(adj_A, A.grad)
-
-    # test zero
-    tape.zero()
-    assert_array_equal(A.grad, wp.zeros_like(A))
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_large_batch_count(test, device):
-    rng = np.random.default_rng(42)
-    low = -4.5
-    high = 3.5
-    m = 2
-    n = 3
-    k = 4
-    batch_count = 65535 * 2 + int(65535 / 2)
-    A = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-        dtype=float,
-        device=device,
-        requires_grad=True,
-    )
-    B = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-        dtype=float,
-        device=device,
-        requires_grad=True,
-    )
-    C = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-        dtype=float,
-        device=device,
-        requires_grad=True,
-    )
-    D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
-    ones = wp.zeros_like(D)
-    ones.fill_(1.0)
-
-    alpha = 1.0
-    beta = 1.0
-
-    tape = wp.Tape()
-    with tape:
-        wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False)
-    tape.backward(grads={D: ones})
-
-    D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
-    assert_np_equal(D.numpy(), D_np)
-
-    adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
-    adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
-    adj_C_np = beta * ones.numpy()
-
-    assert_np_equal(A.grad.numpy(), adj_A_np)
-    assert_np_equal(B.grad.numpy(), adj_B_np)
-    assert_np_equal(C.grad.numpy(), adj_C_np)
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_adjoint_accumulation(test, device):
-    a_np = np.ones(shape=(2, 3))
-    b_np = np.ones(shape=(3, 2))
-    c_np = np.zeros(shape=(2, 2))
-    d_np = np.zeros(shape=(2, 2))
-
-    a_wp = wp.from_numpy(a_np, dtype=float, requires_grad=True, device=device)
-    b_wp = wp.from_numpy(b_np, dtype=float, requires_grad=True, device=device)
-    c_wp = wp.from_numpy(c_np, dtype=float, requires_grad=True, device=device)
-    d1_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True, device=device)
-    d2_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True, device=device)
-
-    tape = wp.Tape()
-
-    with tape:
-        wp.matmul(a_wp, b_wp, c_wp, d1_wp, alpha=1.0, beta=1.0)
-        wp.matmul(a_wp, b_wp, d1_wp, d2_wp, alpha=1.0, beta=1.0)
-
-    d_grad = wp.zeros_like(d2_wp, device=device)
-    d_grad.fill_(1.0)
-    grads = {d2_wp: d_grad}
-    tape.backward(grads=grads)
-
-    assert_np_equal(a_wp.grad.numpy(), 4.0 * np.ones(shape=(2, 3)))
-    assert_np_equal(b_wp.grad.numpy(), 4.0 * np.ones(shape=(3, 2)))
-    assert_np_equal(c_wp.grad.numpy(), np.ones(shape=(2, 2)))
-
-
-@unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_cuda_graph_capture(test, device):
-    @wp.kernel
-    def mat_sum(mat: wp.array2d(dtype=Any), loss: wp.array(dtype=Any)):
-        i, j = wp.tid()
-        e = mat[i, j]
-        wp.atomic_add(loss, 0, e)
-
-    for T in [wp.float16, wp.float32, wp.float64]:
-        wp.overload(mat_sum, [wp.array2d(dtype=T), wp.array(dtype=T)])
-
-    wp.load_module(device=device)
-    wp.load_module(module="warp.utils", device=device)
-
-    for dtype in [wp.float16, wp.float32, wp.float64]:
-        m = 8
-        n = 8
-        k = 8
-
-        A = wp.ones((m, n), dtype=dtype, device=device, requires_grad=True)
-        B = wp.ones((n, k), dtype=dtype, device=device, requires_grad=True)
-        C = wp.zeros((m, k), dtype=dtype, device=device, requires_grad=True)
-        D = wp.zeros((m, k), dtype=dtype, device=device, requires_grad=True)
-
-        loss = wp.zeros(1, dtype=dtype, device=device, requires_grad=True)
-
-        wp.capture_begin(device, force_module_load=False)
-        try:
-            tape = wp.Tape()
-
-            with tape:
-                wp.matmul(A, B, C, D)
-                wp.launch(mat_sum, dim=(m, k), inputs=[D, loss], device=device)
-
-            tape.backward(loss=loss)
-        finally:
-            graph = wp.capture_end(device)
-
-        wp.capture_launch(graph)
-
-        assert_np_equal(A.grad.numpy(), 8.0 * np.ones((m, n), dtype=wp.types.warp_type_to_np_dtype[dtype]))
-
-
-devices = get_test_devices()
-cuda_devices = get_selected_cuda_test_devices()
-
-
-class TestMatmul(unittest.TestCase):
-    pass
-
-
-# add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
-add_function_test(TestMatmul, "test_f32", test_f32, devices=devices, check_output=False)
-add_function_test(TestMatmul, "test_f64", test_f64, devices=devices, check_output=False)
-add_function_test(TestMatmul, "test_tape", test_tape, devices=devices, check_output=False)
-add_function_test(TestMatmul, "test_operator", test_operator, devices=devices, check_output=False)
-add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices, check_output=False)
-add_function_test(
-    TestMatmul, "test_adjoint_accumulation", test_adjoint_accumulation, devices=devices, check_output=False
-)
-add_function_test(
-    TestMatmul, "test_cuda_graph_capture", test_cuda_graph_capture, devices=cuda_devices, check_output=False
-)
-
-
-if __name__ == "__main__":
-    wp.clear_kernel_cache()
-    unittest.main(verbosity=2, failfast=False)