warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/warp-clang.dll +0 -0
  3. warp/bin/warp.dll +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
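
Note on the layout change above (entries 58-61): warp/jax_experimental.py becomes a package with custom_call.py, ffi.py, and xla_ffi.py. The following is a hypothetical smoke test, not part of the diff, that checks the version bump and the new files after installing the 1.7.0 wheel; it only inspects files on disk so JAX does not need to be installed.

    from pathlib import Path

    import warp as wp

    # Installed Warp version; after this upgrade it should report 1.7.0.
    print(wp.config.version)

    # The 1.7.0 wheel replaces warp/jax_experimental.py with a package
    # (entries 58-61 above); check the new files without importing them.
    pkg_dir = Path(wp.__file__).parent / "jax_experimental"
    for name in ("__init__.py", "custom_call.py", "ffi.py", "xla_ffi.py"):
        print(name, "found" if (pkg_dir / name).is_file() else "missing")
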
warp/tests/test_matmul_lite.py (file removed in 1.7.0)
@@ -1,411 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import unittest
-
- import numpy as np
-
- import warp as wp
- from warp.tests.unittest_utils import *
-
- wp.init()  # For wp.context.runtime.core.is_cutlass_enabled()
-
-
- class gemm_test_bed_runner:
-     def __init__(self, dtype, device):
-         self.dtype = dtype
-         self.device = device
-
-     def alloc(self, m, n, k, batch_count):
-         rng = np.random.default_rng(42)
-         low = -4.5
-         high = 3.5
-         if batch_count == 1:
-             A = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-         else:
-             A = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-         return A, B, C, D
-
-     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-         A, B, C, D = self.alloc(m, n, k, batch_count)
-         ones = wp.zeros_like(D)
-         ones.fill_(1.0)
-
-         if batch_count == 1:
-             tape = wp.Tape()
-             with tape:
-                 wp.matmul(A, B, C, D, alpha, beta, False)
-             tape.backward(grads={D: ones})
-
-             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
-             assert_np_equal(D.numpy(), D_np)
-
-             adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose())
-             adj_B_np = alpha * (A.numpy().transpose() @ ones.numpy())
-             adj_C_np = beta * ones.numpy()
-
-         else:
-             tape = wp.Tape()
-             with tape:
-                 wp.batched_matmul(A, B, C, D, alpha, beta, False)
-             tape.backward(grads={D: ones})
-
-             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
-             assert_np_equal(D.numpy(), D_np)
-
-             adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
-             adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
-             adj_C_np = beta * ones.numpy()
-
-         assert_np_equal(A.grad.numpy(), adj_A_np)
-         assert_np_equal(B.grad.numpy(), adj_B_np)
-         assert_np_equal(C.grad.numpy(), adj_C_np)
-
-     def run(self):
-         m = 8
-         n = 16
-         k = 32
-         batch_count = 1
-         beta = 1.0
-         alpha = 1.0
-
-         self.run_and_verify(m, n, k, batch_count, alpha, beta)
-
-
- class gemm_test_bed_runner_transpose:
-     def __init__(self, dtype, device):
-         self.dtype = dtype
-         self.device = device
-
-     def alloc(self, m, n, k, batch_count):
-         rng = np.random.default_rng(42)
-         low = -4.5
-         high = 3.5
-         if batch_count == 1:
-             A = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-             AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
-             BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
-         else:
-             A = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-             AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
-             BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
-         return A, B, C, D, AT, BT
-
-     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-         A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
-         C2 = wp.clone(C1)
-         C3 = wp.clone(C1)
-         D2 = wp.clone(D1)
-         D3 = wp.clone(D1)
-         AT2 = wp.clone(AT1)
-         BT2 = wp.clone(BT1)
-         ones1 = wp.zeros_like(D1)
-         ones1.fill_(1.0)
-         ones2 = wp.zeros_like(D2)
-         ones2.fill_(1.0)
-         ones3 = wp.zeros_like(D3)
-         ones3.fill_(1.0)
-
-         if batch_count == 1:
-             ATT1 = AT1.transpose([1, 0])
-             BTT1 = BT1.transpose([1, 0])
-             ATT2 = AT2.transpose([1, 0])
-             BTT2 = BT2.transpose([1, 0])
-             tape = wp.Tape()
-             with tape:
-                 wp.matmul(A, BTT1, C1, D1, alpha, beta, False)
-                 wp.matmul(ATT1, B, C2, D2, alpha, beta, False)
-                 wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
-             tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
-
-             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
-             assert_np_equal(D1.numpy(), D_np)
-             assert_np_equal(D2.numpy(), D_np)
-             assert_np_equal(D3.numpy(), D_np)
-
-             adj_A_np = alpha * (ones1.numpy() @ B.numpy().transpose())
-             adj_B_np = alpha * (A.numpy().transpose() @ ones1.numpy())
-             adj_C_np = beta * ones1.numpy()
-
-         else:
-             ATT1 = AT1.transpose([0, 2, 1])
-             BTT1 = BT1.transpose([0, 2, 1])
-             ATT2 = AT2.transpose([0, 2, 1])
-             BTT2 = BT2.transpose([0, 2, 1])
-             tape = wp.Tape()
-             with tape:
-                 wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False)
-                 wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False)
-                 wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
-             tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
-
-             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
-             assert_np_equal(D1.numpy(), D_np)
-             assert_np_equal(D2.numpy(), D_np)
-             assert_np_equal(D3.numpy(), D_np)
-
-             adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)))
-             adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy())
-             adj_C_np = beta * ones1.numpy()
-
-         assert_np_equal(A.grad.numpy(), adj_A_np)
-         assert_np_equal(ATT1.grad.numpy(), adj_A_np)
-         assert_np_equal(ATT2.grad.numpy(), adj_A_np)
-         assert_np_equal(B.grad.numpy(), adj_B_np)
-         assert_np_equal(BTT1.grad.numpy(), adj_B_np)
-         assert_np_equal(BTT2.grad.numpy(), adj_B_np)
-         assert_np_equal(C1.grad.numpy(), adj_C_np)
-         assert_np_equal(C2.grad.numpy(), adj_C_np)
-         assert_np_equal(C3.grad.numpy(), adj_C_np)
-
-     def run(self):
-         m = 8
-         n = 16
-         k = 32
-         batch_counts = [1, 4]
-         beta = 1.0
-         alpha = 1.0
-
-         for batch_count in batch_counts:
-             self.run_and_verify(m, n, k, batch_count, alpha, beta)
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_f32(test, device):
-     gemm_test_bed_runner(wp.float32, device).run()
-     gemm_test_bed_runner_transpose(wp.float32, device).run()
-
-
- @wp.kernel
- def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
-     i, j = wp.tid()
-     wp.atomic_add(loss, 0, arr[i, j])
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_tape(test, device):
-     rng = np.random.default_rng(42)
-     low = -4.5
-     high = 3.5
-     m = 8
-     n = 16
-     k = 32
-     A = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
-     )
-     B = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
-     )
-     C = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
-     )
-     D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
-     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
-
-     # test tape
-     tape = wp.Tape()
-     with tape:
-         wp.matmul(A, B, C, D)
-         wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
-
-     tape.backward(loss=loss)
-     A_grad = A.grad.numpy()
-     tape.reset()
-
-     # test adjoint
-     D.grad = wp.ones((m, n), dtype=float, device=device)
-     wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad)
-     assert_np_equal(A_grad, A.grad.numpy())
-
-     # test zero
-     tape.zero()
-     assert_array_equal(A.grad, wp.zeros_like(A))
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_operator(test, device):
-     rng = np.random.default_rng(42)
-     low = -4.5
-     high = 3.5
-     m = 8
-     n = 16
-     k = 32
-     A = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
-     )
-     B = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
-     )
-     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
-
-     # test tape
-     tape = wp.Tape()
-     with tape:
-         D = A @ B
-         wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
-
-     tape.backward(loss=loss)
-
-     # test adjoint
-     D.grad = wp.ones((m, n), dtype=float, device=device)
-     B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
-
-     adj_A = D.grad @ B_transpose
-     assert_array_equal(adj_A, A.grad)
-
-     # test zero
-     tape.zero()
-     assert_array_equal(A.grad, wp.zeros_like(A))
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_large_batch_count(test, device):
-     rng = np.random.default_rng(42)
-     low = -4.5
-     high = 3.5
-     m = 2
-     n = 3
-     k = 4
-     batch_count = 65535 * 2 + int(65535 / 2)
-     A = wp.array3d(
-         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-         dtype=float,
-         device=device,
-         requires_grad=True,
-     )
-     B = wp.array3d(
-         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-         dtype=float,
-         device=device,
-         requires_grad=True,
-     )
-     C = wp.array3d(
-         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-         dtype=float,
-         device=device,
-         requires_grad=True,
-     )
-     D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
-     ones = wp.zeros_like(D)
-     ones.fill_(1.0)
-
-     alpha = 1.0
-     beta = 1.0
-
-     tape = wp.Tape()
-     with tape:
-         wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False)
-     tape.backward(grads={D: ones})
-
-     D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
-     assert_np_equal(D.numpy(), D_np)
-
-     adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
-     adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
-     adj_C_np = beta * ones.numpy()
-
-     assert_np_equal(A.grad.numpy(), adj_A_np)
-     assert_np_equal(B.grad.numpy(), adj_B_np)
-     assert_np_equal(C.grad.numpy(), adj_C_np)
-
-
- devices = get_test_devices()
-
-
- class TestMatmulLite(unittest.TestCase):
-     pass
-
-
- add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices, check_output=False)
- add_function_test(TestMatmulLite, "test_tape", test_tape, devices=devices, check_output=False)
- add_function_test(TestMatmulLite, "test_operator", test_operator, devices=devices, check_output=False)
- add_function_test(TestMatmulLite, "test_large_batch_count", test_large_batch_count, devices=devices, check_output=False)
-
-
- if __name__ == "__main__":
-     wp.clear_kernel_cache()
-     unittest.main(verbosity=2, failfast=False)
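
For context, the removed file validated wp.matmul / wp.batched_matmul against a NumPy reference and checked the adjoints produced by wp.Tape. The following is a standalone restatement of that reference computation (a sketch only, NumPy with no Warp required; shapes, seed, and alpha/beta are taken from the deleted test):

    import numpy as np

    # Forward check used by the deleted test: D = alpha * (A @ B) + beta * C.
    # With an upstream gradient G of ones, the adjoints it compared against are
    # dA = alpha * G @ B^T, dB = alpha * A^T @ G, dC = beta * G.
    rng = np.random.default_rng(42)
    m, n, k = 8, 16, 32
    alpha, beta = 1.0, 1.0
    A = np.ceil(rng.uniform(-4.5, 3.5, size=(m, k)))
    B = np.ceil(rng.uniform(-4.5, 3.5, size=(k, n)))
    C = np.ceil(rng.uniform(-4.5, 3.5, size=(m, n)))
    G = np.ones((m, n))

    D = alpha * (A @ B) + beta * C
    adj_A = alpha * (G @ B.T)
    adj_B = alpha * (A.T @ G)
    adj_C = beta * G
    print(D.shape, adj_A.shape, adj_B.shape, adj_C.shape)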