warp-lang 1.6.2-py3-none-win_amd64.whl → 1.7.1-py3-none-win_amd64.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/tests/test_matmul_lite.py
@@ -1,411 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import unittest
-
- import numpy as np
-
- import warp as wp
- from warp.tests.unittest_utils import *
-
- wp.init()  # For wp.context.runtime.core.is_cutlass_enabled()
-
-
- class gemm_test_bed_runner:
-     def __init__(self, dtype, device):
-         self.dtype = dtype
-         self.device = device
-
-     def alloc(self, m, n, k, batch_count):
-         rng = np.random.default_rng(42)
-         low = -4.5
-         high = 3.5
-         if batch_count == 1:
-             A = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-         else:
-             A = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-         return A, B, C, D
-
-     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-         A, B, C, D = self.alloc(m, n, k, batch_count)
-         ones = wp.zeros_like(D)
-         ones.fill_(1.0)
-
-         if batch_count == 1:
-             tape = wp.Tape()
-             with tape:
-                 wp.matmul(A, B, C, D, alpha, beta, False)
-             tape.backward(grads={D: ones})
-
-             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
-             assert_np_equal(D.numpy(), D_np)
-
-             adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose())
-             adj_B_np = alpha * (A.numpy().transpose() @ ones.numpy())
-             adj_C_np = beta * ones.numpy()
-
-         else:
-             tape = wp.Tape()
-             with tape:
-                 wp.batched_matmul(A, B, C, D, alpha, beta, False)
-             tape.backward(grads={D: ones})
-
-             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
-             assert_np_equal(D.numpy(), D_np)
-
-             adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
-             adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
-             adj_C_np = beta * ones.numpy()
-
-         assert_np_equal(A.grad.numpy(), adj_A_np)
-         assert_np_equal(B.grad.numpy(), adj_B_np)
-         assert_np_equal(C.grad.numpy(), adj_C_np)
-
-     def run(self):
-         m = 8
-         n = 16
-         k = 32
-         batch_count = 1
-         beta = 1.0
-         alpha = 1.0
-
-         self.run_and_verify(m, n, k, batch_count, alpha, beta)
-
-
- class gemm_test_bed_runner_transpose:
-     def __init__(self, dtype, device):
-         self.dtype = dtype
-         self.device = device
-
-     def alloc(self, m, n, k, batch_count):
-         rng = np.random.default_rng(42)
-         low = -4.5
-         high = 3.5
-         if batch_count == 1:
-             A = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array2d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-             AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
-             BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
-         else:
-             A = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             B = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             C = wp.array3d(
-                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-                 dtype=self.dtype,
-                 device=self.device,
-                 requires_grad=True,
-             )
-             D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
-             AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
-             BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
-         return A, B, C, D, AT, BT
-
-     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-         A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
-         C2 = wp.clone(C1)
-         C3 = wp.clone(C1)
-         D2 = wp.clone(D1)
-         D3 = wp.clone(D1)
-         AT2 = wp.clone(AT1)
-         BT2 = wp.clone(BT1)
-         ones1 = wp.zeros_like(D1)
-         ones1.fill_(1.0)
-         ones2 = wp.zeros_like(D2)
-         ones2.fill_(1.0)
-         ones3 = wp.zeros_like(D3)
-         ones3.fill_(1.0)
-
-         if batch_count == 1:
-             ATT1 = AT1.transpose([1, 0])
-             BTT1 = BT1.transpose([1, 0])
-             ATT2 = AT2.transpose([1, 0])
-             BTT2 = BT2.transpose([1, 0])
-             tape = wp.Tape()
-             with tape:
-                 wp.matmul(A, BTT1, C1, D1, alpha, beta, False)
-                 wp.matmul(ATT1, B, C2, D2, alpha, beta, False)
-                 wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
-             tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
-
-             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
-             assert_np_equal(D1.numpy(), D_np)
-             assert_np_equal(D2.numpy(), D_np)
-             assert_np_equal(D3.numpy(), D_np)
-
-             adj_A_np = alpha * (ones1.numpy() @ B.numpy().transpose())
-             adj_B_np = alpha * (A.numpy().transpose() @ ones1.numpy())
-             adj_C_np = beta * ones1.numpy()
-
-         else:
-             ATT1 = AT1.transpose([0, 2, 1])
-             BTT1 = BT1.transpose([0, 2, 1])
-             ATT2 = AT2.transpose([0, 2, 1])
-             BTT2 = BT2.transpose([0, 2, 1])
-             tape = wp.Tape()
-             with tape:
-                 wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False)
-                 wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False)
-                 wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False)
-             tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
-
-             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
-             assert_np_equal(D1.numpy(), D_np)
-             assert_np_equal(D2.numpy(), D_np)
-             assert_np_equal(D3.numpy(), D_np)
-
-             adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)))
-             adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy())
-             adj_C_np = beta * ones1.numpy()
-
-         assert_np_equal(A.grad.numpy(), adj_A_np)
-         assert_np_equal(ATT1.grad.numpy(), adj_A_np)
-         assert_np_equal(ATT2.grad.numpy(), adj_A_np)
-         assert_np_equal(B.grad.numpy(), adj_B_np)
-         assert_np_equal(BTT1.grad.numpy(), adj_B_np)
-         assert_np_equal(BTT2.grad.numpy(), adj_B_np)
-         assert_np_equal(C1.grad.numpy(), adj_C_np)
-         assert_np_equal(C2.grad.numpy(), adj_C_np)
-         assert_np_equal(C3.grad.numpy(), adj_C_np)
-
-     def run(self):
-         m = 8
-         n = 16
-         k = 32
-         batch_counts = [1, 4]
-         beta = 1.0
-         alpha = 1.0
-
-         for batch_count in batch_counts:
-             self.run_and_verify(m, n, k, batch_count, alpha, beta)
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_f32(test, device):
-     gemm_test_bed_runner(wp.float32, device).run()
-     gemm_test_bed_runner_transpose(wp.float32, device).run()
-
-
- @wp.kernel
- def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
-     i, j = wp.tid()
-     wp.atomic_add(loss, 0, arr[i, j])
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_tape(test, device):
-     rng = np.random.default_rng(42)
-     low = -4.5
-     high = 3.5
-     m = 8
-     n = 16
-     k = 32
-     A = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
-     )
-     B = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
-     )
-     C = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
-     )
-     D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
-     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
-
-     # test tape
-     tape = wp.Tape()
-     with tape:
-         wp.matmul(A, B, C, D)
-         wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
-
-     tape.backward(loss=loss)
-     A_grad = A.grad.numpy()
-     tape.reset()
-
-     # test adjoint
-     D.grad = wp.ones((m, n), dtype=float, device=device)
-     wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad)
-     assert_np_equal(A_grad, A.grad.numpy())
-
-     # test zero
-     tape.zero()
-     assert_array_equal(A.grad, wp.zeros_like(A))
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_operator(test, device):
-     rng = np.random.default_rng(42)
-     low = -4.5
-     high = 3.5
-     m = 8
-     n = 16
-     k = 32
-     A = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
-     )
-     B = wp.array2d(
-         np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
-     )
-     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
-
-     # test tape
-     tape = wp.Tape()
-     with tape:
-         D = A @ B
-         wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
-
-     tape.backward(loss=loss)
-
-     # test adjoint
-     D.grad = wp.ones((m, n), dtype=float, device=device)
-     B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
-
-     adj_A = D.grad @ B_transpose
-     assert_array_equal(adj_A, A.grad)
-
-     # test zero
-     tape.zero()
-     assert_array_equal(A.grad, wp.zeros_like(A))
-
-
- @unittest.skipUnless(wp.context.runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
- def test_large_batch_count(test, device):
-     rng = np.random.default_rng(42)
-     low = -4.5
-     high = 3.5
-     m = 2
-     n = 3
-     k = 4
-     batch_count = 65535 * 2 + int(65535 / 2)
-     A = wp.array3d(
-         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
-         dtype=float,
-         device=device,
-         requires_grad=True,
-     )
-     B = wp.array3d(
-         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
-         dtype=float,
-         device=device,
-         requires_grad=True,
-     )
-     C = wp.array3d(
-         np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
-         dtype=float,
-         device=device,
-         requires_grad=True,
-     )
-     D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
-     ones = wp.zeros_like(D)
-     ones.fill_(1.0)
-
-     alpha = 1.0
-     beta = 1.0
-
-     tape = wp.Tape()
-     with tape:
-         wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False)
-     tape.backward(grads={D: ones})
-
-     D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
-     assert_np_equal(D.numpy(), D_np)
-
-     adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
-     adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
-     adj_C_np = beta * ones.numpy()
-
-     assert_np_equal(A.grad.numpy(), adj_A_np)
-     assert_np_equal(B.grad.numpy(), adj_B_np)
-     assert_np_equal(C.grad.numpy(), adj_C_np)
-
-
- devices = get_test_devices()
-
-
- class TestMatmulLite(unittest.TestCase):
-     pass
-
-
- add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices, check_output=False)
- add_function_test(TestMatmulLite, "test_tape", test_tape, devices=devices, check_output=False)
- add_function_test(TestMatmulLite, "test_operator", test_operator, devices=devices, check_output=False)
- add_function_test(TestMatmulLite, "test_large_batch_count", test_large_batch_count, devices=devices, check_output=False)
-
-
- if __name__ == "__main__":
-     wp.clear_kernel_cache()
-     unittest.main(verbosity=2, failfast=False)