warp_lang-1.0.0b5-py3-none-manylinux2014_x86_64.whl → warp_lang-1.0.0b6-py3-none-manylinux2014_x86_64.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_math.py CHANGED
@@ -5,13 +5,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-from typing import NamedTuple
 import unittest
+from typing import NamedTuple
 
 import numpy as np
 
 import warp as wp
-from warp.tests.test_base import *
+from warp.tests.unittest_utils import *
 
 wp.init()
 
@@ -176,19 +176,18 @@ def test_mat_type(test, device):
         raise ValueError("mat to string error")
 
 
-def register(parent):
-    devices = get_test_devices()
+devices = get_test_devices()
+
+
+class TestMath(unittest.TestCase):
+    pass
 
-    class TestMath(parent):
-        pass
 
-    add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
-    add_function_test(TestMath, "test_vec_type", test_vec_type, devices=devices)
-    add_function_test(TestMath, "test_mat_type", test_mat_type, devices=devices)
-    return TestMath
+add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
+add_function_test(TestMath, "test_vec_type", test_vec_type, devices=devices)
+add_function_test(TestMath, "test_mat_type", test_mat_type, devices=devices)
 
 
 if __name__ == "__main__":
     wp.build.clear_kernel_cache()
-    _ = register(unittest.TestCase)
     unittest.main(verbosity=2)
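
A note on the registration change above: the old register(parent) factory is gone, and each test module now builds a plain module-level unittest.TestCase at import time, with add_function_test attaching one test per device (the helpers come from the renamed warp.tests.unittest_utils module). A minimal standalone sketch of the new convention; the trivial test body here is illustrative only, not part of the diff:

import unittest

import warp as wp
from warp.tests.unittest_utils import add_function_test, get_test_devices

wp.init()


def test_device_alias(test, device):
    # test functions keep the (test, device) signature: `test` is the
    # TestCase instance, `device` a Warp device alias; this body is a
    # placeholder for a real check
    test.assertIsNotNone(wp.get_device(device))


devices = get_test_devices()


class TestExample(unittest.TestCase):
    pass


# tests are registered at module scope, so standard unittest discovery
# finds them without calling a register() helper first
add_function_test(TestExample, "test_device_alias", test_device_alias, devices=devices)


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2)
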
warp/tests/test_matmul.py CHANGED
@@ -1,11 +1,21 @@
-import numpy as np
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
 import unittest
 
+import numpy as np
+
 import warp as wp
-from warp.tests.test_base import *
+from warp.tests.unittest_utils import *
 
 wp.init()
 
+from warp.context import runtime  # noqa: E402
+
 
 class gemm_test_bed_runner:
     def __init__(self, dtype, device):
@@ -21,63 +31,54 @@ class gemm_test_bed_runner:
                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
-            D = wp.array2d(
-                np.zeros((m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True)
+            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
         else:
             A = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
-            )
-            D = wp.array3d(
-                np.zeros((batch_count, m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
+            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
         return A, B, C, D
 
     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
         A, B, C, D = self.alloc(m, n, k, batch_count)
         ones = wp.zeros_like(D)
         ones.fill_(1.0)
-
+
         if batch_count == 1:
             tape = wp.Tape()
             with tape:
                 wp.matmul(A, B, C, D, alpha, beta, False, self.device)
-            tape.backward(grads={D : ones})
-
+            tape.backward(grads={D: ones})
+
             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
             assert np.array_equal(D_np, D.numpy())
@@ -89,8 +90,8 @@ class gemm_test_bed_runner:
             tape = wp.Tape()
             with tape:
                 wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
-            tape.backward(grads={D : ones})
-
+            tape.backward(grads={D: ones})
+
             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
             assert np.array_equal(D_np, D.numpy())
 
@@ -132,75 +133,45 @@ class gemm_test_bed_runner_transpose:
                 np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             B = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array2d(
                 np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
-            )
-            D = wp.array2d(
-                np.zeros((m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            AT = wp.array2d(
-                A.numpy().transpose([1, 0]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            BT = wp.array2d(
-                B.numpy().transpose([1, 0]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
+            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+            AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+            BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
         else:
             A = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
            )
             B = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
             C = wp.array3d(
                 np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
                 dtype=self.dtype,
                 device=self.device,
-                requires_grad=True
-            )
-            D = wp.array3d(
-                np.zeros((batch_count, m, n)),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            AT = wp.array3d(
-                A.numpy().transpose([0, 2, 1]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
-            )
-            BT = wp.array3d(
-                B.numpy().transpose([0, 2, 1]),
-                dtype=self.dtype,
-                device=self.device,
-                requires_grad=True
+                requires_grad=True,
             )
+            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+            AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+            BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
         return A, B, C, D, AT, BT
 
     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
@@ -219,17 +190,17 @@
             ones3.fill_(1.0)
 
         if batch_count == 1:
-            ATT1 = AT1.transpose([1, 0]) 
+            ATT1 = AT1.transpose([1, 0])
             BTT1 = BT1.transpose([1, 0])
-            ATT2 = AT2.transpose([1, 0]) 
+            ATT2 = AT2.transpose([1, 0])
             BTT2 = BT2.transpose([1, 0])
             tape = wp.Tape()
             with tape:
                 wp.matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
                 wp.matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
                 wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
-            tape.backward(grads={D1 : ones1, D2 : ones2, D3 : ones3})
-
+            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
             assert np.array_equal(D_np, D1.numpy())
             assert np.array_equal(D_np, D2.numpy())
@@ -240,7 +211,7 @@
             adj_C_np = beta * ones1.numpy()
 
         else:
-            ATT1 = AT1.transpose([0, 2, 1]) 
+            ATT1 = AT1.transpose([0, 2, 1])
             BTT1 = BT1.transpose([0, 2, 1])
             ATT2 = AT2.transpose([0, 2, 1])
             BTT2 = BT2.transpose([0, 2, 1])
@@ -249,8 +220,8 @@
                 wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
                 wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
                 wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
-            tape.backward(grads={D1 : ones1, D2 : ones2, D3 : ones3})
-
+            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
             assert np.array_equal(D_np, D1.numpy())
             assert np.array_equal(D_np, D2.numpy())
@@ -288,11 +259,13 @@ def test_f16(test, device):
     gemm_test_bed_runner_transpose(wp.float16, device).run()
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_f32(test, device):
     gemm_test_bed_runner(wp.float32, device).run()
     gemm_test_bed_runner_transpose(wp.float32, device).run()
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_f64(test, device):
     gemm_test_bed_runner(wp.float64, device).run()
     gemm_test_bed_runner_transpose(wp.float64, device).run()
@@ -304,6 +277,7 @@ def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
     wp.atomic_add(loss, 0, arr[i, j])
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_tape(test, device):
     rng = np.random.default_rng(42)
     low = -4.5
@@ -331,6 +305,7 @@ def test_tape(test, device):
 
     tape.backward(loss=loss)
     A_grad = A.grad.numpy()
+    tape.reset()
 
     # test adjoint
     D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
@@ -342,6 +317,7 @@ def test_tape(test, device):
     assert_array_equal(A.grad, wp.zeros_like(A))
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_operator(test, device):
     rng = np.random.default_rng(42)
     low = -4.5
@@ -377,6 +353,7 @@ def test_operator(test, device):
     assert_array_equal(A.grad, wp.zeros_like(A))
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_large_batch_count(test, device):
     rng = np.random.default_rng(42)
     low = -4.5
@@ -386,31 +363,38 @@ def test_large_batch_count(test, device):
     k = 4
     batch_count = 65535 * 2 + int(65535 / 2)
     A = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
     )
     B = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
    )
     C = wp.array3d(
-        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))), dtype=float, device=device, requires_grad=True
-    )
-    D = wp.array3d(
-        np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
    )
+    D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
     ones = wp.zeros_like(D)
     ones.fill_(1.0)
 
     alpha = 1.0
     beta = 1.0
-
+
     tape = wp.Tape()
     with tape:
        wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False, device=device)
-    tape.backward(grads={D : ones})
+    tape.backward(grads={D: ones})
 
     D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
     assert np.array_equal(D_np, D.numpy())
-
+
     adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
     adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
     adj_C_np = beta * ones.numpy()
@@ -420,30 +404,50 @@ def test_large_batch_count(test, device):
     assert np.array_equal(adj_C_np, C.grad.numpy())
 
 
-def register(parent):
-    devices = [d for d in get_test_devices()]
+def test_adjoint_accumulation(test, device):
+    a_np = np.ones(shape=(2,3))
+    b_np = np.ones(shape=(3,2))
+    c_np = np.zeros(shape=(2,2))
+    d_np = np.zeros(shape=(2,2))
 
-    class TestMatmul(parent):
-        pass
+    a_wp = wp.from_numpy(a_np, dtype=float, requires_grad=True)
+    b_wp = wp.from_numpy(b_np, dtype=float, requires_grad=True)
+    c_wp = wp.from_numpy(c_np, dtype=float, requires_grad=True)
+    d1_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True)
+    d2_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True)
 
-    if devices:
-        # check if CUTLASS is available
-        from warp.context import runtime
+    tape = wp.Tape()
+
+    with tape:
+        wp.matmul(a_wp, b_wp, c_wp, d1_wp, alpha=1.0, beta=1.0)
+        wp.matmul(a_wp, b_wp, d1_wp, d2_wp, alpha=1.0, beta=1.0)
+
+    d_grad = wp.zeros_like(d2_wp)
+    d_grad.fill_(1.)
+    grads = {d2_wp : d_grad}
+    tape.backward(grads=grads)
+
+    assert np.array_equal(a_wp.grad.numpy(), 4.0 * np.ones(shape=(2,3)))
+    assert np.array_equal(b_wp.grad.numpy(), 4.0 * np.ones(shape=(3,2)))
+    assert np.array_equal(c_wp.grad.numpy(), np.ones(shape=(2,2)))
+
+
+devices = get_test_devices()
+
+
+class TestMatmul(unittest.TestCase):
+    pass
 
-        if runtime.core.is_cutlass_enabled():
-            # add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
-            add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
-            add_function_test(TestMatmul, "test_f64", test_f64, devices=devices)
-            add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
-            add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
-            add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices)
-        else:
-            print("Skipping matmul tests because CUTLASS is not supported in this build")
 
-    return TestMatmul
+# add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
+add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
+add_function_test(TestMatmul, "test_f64", test_f64, devices=devices)
+add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
+add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
+add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices)
+add_function_test(TestMatmul, "test_adjoint_accumulation", test_adjoint_accumulation, devices=devices)
 
 
 if __name__ == "__main__":
     wp.build.clear_kernel_cache()
-    _ = register(unittest.TestCase)
     unittest.main(verbosity=2, failfast=False)
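
Two behavioral changes are worth noting in this file beyond the formatting cleanup: CUTLASS gating moves from conditional registration inside register() to per-test @unittest.skipUnless(runtime.core.is_cutlass_enabled(), ...) decorators, so unsupported builds now report skipped tests instead of silently dropping them, and the new test_adjoint_accumulation case checks that adjoints accumulate when one matmul feeds another. A minimal standalone sketch of that check, assuming a build and default device where wp.matmul is usable:

import numpy as np

import warp as wp

wp.init()

# chain two GEMMs through d1: d1 = a @ b + c, then d2 = a @ b + d1
a = wp.from_numpy(np.ones((2, 3)), dtype=float, requires_grad=True)
b = wp.from_numpy(np.ones((3, 2)), dtype=float, requires_grad=True)
c = wp.from_numpy(np.zeros((2, 2)), dtype=float, requires_grad=True)
d1 = wp.from_numpy(np.zeros((2, 2)), dtype=float, requires_grad=True)
d2 = wp.from_numpy(np.zeros((2, 2)), dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.matmul(a, b, c, d1, alpha=1.0, beta=1.0)
    wp.matmul(a, b, d1, d2, alpha=1.0, beta=1.0)

ones = wp.zeros_like(d2)
ones.fill_(1.0)
tape.backward(grads={d2: ones})

# a and b feed d2 twice (directly and through d1), and each entry of
# ones @ b.T sums two unit products, so their adjoints accumulate to 4
assert np.array_equal(a.grad.numpy(), 4.0 * np.ones((2, 3)))
assert np.array_equal(b.grad.numpy(), 4.0 * np.ones((3, 2)))
assert np.array_equal(c.grad.numpy(), np.ones((2, 2)))
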