warp-lang 1.0.0b2-py3-none-win_amd64.whl → 1.0.0b6-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_math.py CHANGED
@@ -5,13 +5,13 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-from typing import NamedTuple
 import unittest
+from typing import NamedTuple
 
 import numpy as np
 
 import warp as wp
-from warp.tests.test_base import *
+from warp.tests.unittest_utils import *
 
 wp.init()
 
@@ -84,16 +84,110 @@ def test_scalar_math(test, device):
     )
 
 
-def register(parent):
-    devices = get_test_devices()
+def test_vec_type(test, device):
+    vec5 = wp.vec(length=5, dtype=float)
+    v = vec5()
+    w = vec5()
+    a = vec5(1.0)
+    b = vec5(0.0, 0.0, 0.0, 0.0, 0.0)
+    c = vec5(0.0)
+
+    v[0] = 1.0
+    v.x = 0.0
+    v[1:] = [1.0, 1.0, 1.0, 1.0]
+
+    w[0] = 1.0
+    w[1:] = [0.0, 0.0, 0.0, 0.0]
+
+    if v[0] != w[1] or v.x != w.y:
+        raise ValueError("vec setter error")
+
+    for x in v[1:]:
+        if x != 1.0:
+            raise ValueError("vec slicing error")
+
+    if b != c:
+        raise ValueError("vec equality error")
+
+    if str(v) != "[0.0, 1.0, 1.0, 1.0, 1.0]":
+        raise ValueError("vec to string error")
+
+
+def test_mat_type(test, device):
+    mat55 = wp.mat(shape=(5, 5), dtype=float)
+    m1 = mat55()
+    m2 = mat55()
+
+    for i in range(5):
+        for j in range(5):
+            if i == j:
+                m1[i, j] = 1.0
+            else:
+                m1[i, j] = 0.0
+
+    for i in range(5):
+        m2[i] = [1.0, 1.0, 1.0, 1.0, 1.0]
+
+    a = mat55(1.0)
+    b = mat55(
+        1.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        1.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        1.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        1.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        0.0,
+        1.0,
+    )
+
+    if m1 != b:
+        raise ValueError("mat element setting error")
+
+    if m2 != a:
+        raise ValueError("mat row setting error")
+
+    if m1[0, 0] != 1.0:
+        raise ValueError("mat element getting error")
+
+    if m2[0] != [1.0, 1.0, 1.0, 1.0, 1.0]:
+        raise ValueError("mat row getting error")
+
+    if (
+        str(b)
+        != "[[1.0, 0.0, 0.0, 0.0, 0.0],\n [0.0, 1.0, 0.0, 0.0, 0.0],\n [0.0, 0.0, 1.0, 0.0, 0.0],\n [0.0, 0.0, 0.0, 1.0, 0.0],\n [0.0, 0.0, 0.0, 0.0, 1.0]]"
+    ):
+        raise ValueError("mat to string error")
+
+
+devices = get_test_devices()
+
+
+class TestMath(unittest.TestCase):
+    pass
 
-    class TestMath(parent):
-        pass
 
-    add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
-    return TestMath
+add_function_test(TestMath, "test_scalar_math", test_scalar_math, devices=devices)
+add_function_test(TestMath, "test_vec_type", test_vec_type, devices=devices)
+add_function_test(TestMath, "test_mat_type", test_mat_type, devices=devices)
 
 
 if __name__ == "__main__":
-    _ = register(unittest.TestCase)
+    wp.build.clear_kernel_cache()
     unittest.main(verbosity=2)
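
A note on the pattern above, which recurs across most test modules in this release: the per-module register(parent) factory is gone, and test cases are now declared at module scope and populated with add_function_test, with each module clearing the kernel cache before running standalone. A minimal sketch of the new layout, assuming add_function_test and get_test_devices keep the call shapes seen in the diff (the TestExample/test_example names are illustrative only):

    import unittest

    import warp as wp
    from warp.tests.unittest_utils import *  # add_function_test, get_test_devices

    wp.init()


    def test_example(test, device):
        # test functions receive the TestCase instance and a device string
        vec5 = wp.vec(length=5, dtype=float)
        v = vec5(1.0)
        test.assertEqual(v[0], 1.0)


    devices = get_test_devices()


    class TestExample(unittest.TestCase):
        pass


    # presumably generates one test method per device on TestExample
    add_function_test(TestExample, "test_example", test_example, devices=devices)

    if __name__ == "__main__":
        wp.build.clear_kernel_cache()
        unittest.main(verbosity=2)
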
warp/tests/test_matmul.py CHANGED
@@ -1,88 +1,107 @@
-import numpy as np
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
 import unittest
 
-import warp as wp
-from warp.tests.test_base import *
+import numpy as np
 
-np.random.seed(0)
+import warp as wp
+from warp.tests.unittest_utils import *
 
 wp.init()
-wp.config.mode = "debug"
+
+from warp.context import runtime  # noqa: E402
 
 
-class GemmTestbedRunner:
+class gemm_test_bed_runner:
     def __init__(self, dtype, device):
        self.dtype = dtype
         self.device = device
 
     def alloc(self, m, n, k, batch_count):
+        rng = np.random.default_rng(42)
         low = -4.5
         high = 3.5
         if batch_count == 1:
             A = wp.array2d(
-                np.ceil(np.random.uniform(low=low, high=high, size=(m, k))), dtype=self.dtype, device=self.device
+                np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
             )
             B = wp.array2d(
-                np.ceil(np.random.uniform(low=low, high=high, size=(k, n))), dtype=self.dtype, device=self.device
+                np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
             )
             C = wp.array2d(
-                np.ceil(np.random.uniform(low=low, high=high, size=(m, n))), dtype=self.dtype, device=self.device
+                np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
             )
-            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device)
-            adj_A = wp.array2d(np.zeros((m, k)), dtype=self.dtype, device=self.device)
-            adj_B = wp.array2d(np.zeros((k, n)), dtype=self.dtype, device=self.device)
-            adj_C = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device)
-            adj_D = wp.array2d(np.ones((m, n)), dtype=self.dtype, device=self.device)
+            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
         else:
-            A = wp.array2d(
-                np.ceil(np.random.uniform(low=low, high=high, size=(batch_count, m, k))),
+            A = wp.array3d(
+                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
                 dtype=self.dtype,
                 device=self.device,
+                requires_grad=True,
             )
-            B = wp.array2d(
-                np.ceil(np.random.uniform(low=low, high=high, size=(batch_count, k, n))),
+            B = wp.array3d(
+                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
                 dtype=self.dtype,
                 device=self.device,
+                requires_grad=True,
             )
-            C = wp.array2d(
-                np.ceil(np.random.uniform(low=low, high=high, size=(batch_count, m, n))),
+            C = wp.array3d(
+                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
                 dtype=self.dtype,
                 device=self.device,
+                requires_grad=True,
             )
-            D = wp.array2d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device)
-            adj_A = wp.array2d(np.zeros((batch_count, m, k)), dtype=self.dtype, device=self.device)
-            adj_B = wp.array2d(np.zeros((batch_count, k, n)), dtype=self.dtype, device=self.device)
-            adj_C = wp.array2d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device)
-            adj_D = wp.array2d(np.ones((batch_count, m, n)), dtype=self.dtype, device=self.device)
-        return A, B, C, D, adj_A, adj_B, adj_C, adj_D
+            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+        return A, B, C, D
 
     def run_and_verify(self, m, n, k, batch_count, alpha, beta):
-        A, B, C, D, adj_A, adj_B, adj_C, adj_D = self.alloc(m, n, k, batch_count)
+        A, B, C, D = self.alloc(m, n, k, batch_count)
+        ones = wp.zeros_like(D)
+        ones.fill_(1.0)
+
         if batch_count == 1:
-            wp.matmul(A, B, C, D, alpha, beta, False, self.device)
+            tape = wp.Tape()
+            with tape:
+                wp.matmul(A, B, C, D, alpha, beta, False, self.device)
+            tape.backward(grads={D: ones})
+
             D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
             assert np.array_equal(D_np, D.numpy())
 
-            wp.adj_matmul(A, B, C, adj_A, adj_B, adj_C, adj_D, alpha, beta, False, self.device)
-            adj_A_np = alpha * np.matmul(adj_D.numpy(), B.numpy().transpose())
-            adj_B_np = alpha * (A.numpy().transpose() @ adj_D.numpy())
-            adj_C_np = beta * adj_D.numpy()
+            adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose())
+            adj_B_np = alpha * (A.numpy().transpose() @ ones.numpy())
+            adj_C_np = beta * ones.numpy()
 
-            assert np.array_equal(adj_A_np, adj_A.numpy())
-            assert np.array_equal(adj_B_np, adj_B.numpy())
-            assert np.array_equal(adj_C_np, adj_C.numpy())
         else:
-            wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
+            tape = wp.Tape()
+            with tape:
+                wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
+            tape.backward(grads={D: ones})
+
             D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
             assert np.array_equal(D_np, D.numpy())
 
-            wp.adj_batched_matmul(A, B, C, adj_A, adj_B, adj_C, adj_D, alpha, beta, False, self.device)
-            adj_A_np = alpha * np.matmul(adj_D.numpy(), B.numpy().transpose((0, 2, 1)))
-            adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), adj_D.numpy())
-            adj_C_np = beta * adj_D.numpy()
-            assert np.array_equal(adj_A_np, adj_A.numpy())
-            assert np.array_equal(adj_B_np, adj_B.numpy())
-            assert np.array_equal(adj_C_np, adj_C.numpy())
+            adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
+            adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
+            adj_C_np = beta * ones.numpy()
+
+        assert np.array_equal(adj_A_np, A.grad.numpy())
+        assert np.array_equal(adj_B_np, B.grad.numpy())
+        assert np.array_equal(adj_C_np, C.grad.numpy())
 
     def run(self):
         Ms = [64, 128, 512]
@@ -100,17 +119,156 @@ class GemmTestbedRunner:
                             self.run_and_verify(m, n, k, batch_count, alpha, beta)
 
 
+class gemm_test_bed_runner_transpose:
+    def __init__(self, dtype, device):
+        self.dtype = dtype
+        self.device = device
+
+    def alloc(self, m, n, k, batch_count):
+        rng = np.random.default_rng(42)
+        low = -4.5
+        high = 3.5
+        if batch_count == 1:
+            A = wp.array2d(
+                np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+            B = wp.array2d(
+                np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+            C = wp.array2d(
+                np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+            D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+            AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+            BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
+        else:
+            A = wp.array3d(
+                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+            B = wp.array3d(
+                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+            C = wp.array3d(
+                np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+                dtype=self.dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+            D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
+            AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+            BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
+        return A, B, C, D, AT, BT
+
+    def run_and_verify(self, m, n, k, batch_count, alpha, beta):
+        A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
+        C2 = wp.clone(C1)
+        C3 = wp.clone(C1)
+        D2 = wp.clone(D1)
+        D3 = wp.clone(D1)
+        AT2 = wp.clone(AT1)
+        BT2 = wp.clone(BT1)
+        ones1 = wp.zeros_like(D1)
+        ones1.fill_(1.0)
+        ones2 = wp.zeros_like(D2)
+        ones2.fill_(1.0)
+        ones3 = wp.zeros_like(D3)
+        ones3.fill_(1.0)
+
+        if batch_count == 1:
+            ATT1 = AT1.transpose([1, 0])
+            BTT1 = BT1.transpose([1, 0])
+            ATT2 = AT2.transpose([1, 0])
+            BTT2 = BT2.transpose([1, 0])
+            tape = wp.Tape()
+            with tape:
+                wp.matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
+                wp.matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
+                wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
+            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
+            D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
+            assert np.array_equal(D_np, D1.numpy())
+            assert np.array_equal(D_np, D2.numpy())
+            assert np.array_equal(D_np, D3.numpy())
+
+            adj_A_np = alpha * (ones1.numpy() @ B.numpy().transpose())
+            adj_B_np = alpha * (A.numpy().transpose() @ ones1.numpy())
+            adj_C_np = beta * ones1.numpy()
+
+        else:
+            ATT1 = AT1.transpose([0, 2, 1])
+            BTT1 = BT1.transpose([0, 2, 1])
+            ATT2 = AT2.transpose([0, 2, 1])
+            BTT2 = BT2.transpose([0, 2, 1])
+            tape = wp.Tape()
+            with tape:
+                wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
+                wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
+                wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
+            tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
+
+            D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
+            assert np.array_equal(D_np, D1.numpy())
+            assert np.array_equal(D_np, D2.numpy())
+            assert np.array_equal(D_np, D3.numpy())
+
+            adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)))
+            adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy())
+            adj_C_np = beta * ones1.numpy()
+
+        assert np.array_equal(adj_A_np, A.grad.numpy())
+        assert np.array_equal(adj_A_np, ATT1.grad.numpy())
+        assert np.array_equal(adj_A_np, ATT2.grad.numpy())
+        assert np.array_equal(adj_B_np, B.grad.numpy())
+        assert np.array_equal(adj_B_np, BTT1.grad.numpy())
+        assert np.array_equal(adj_B_np, BTT2.grad.numpy())
+        assert np.array_equal(adj_C_np, C1.grad.numpy())
+        assert np.array_equal(adj_C_np, C2.grad.numpy())
+        assert np.array_equal(adj_C_np, C3.grad.numpy())
+
+    def run(self):
+        m = 16
+        n = 32
+        k = 64
+        batch_counts = [1, 4]
+        beta = 1.0
+        alpha = 1.0
+
+        for batch_count in batch_counts:
+            self.run_and_verify(m, n, k, batch_count, alpha, beta)
+
+
 # NOTE: F16 tests are slow due to the performance of the reference numpy F16 matmuls performed on CPU.
 def test_f16(test, device):
-    GemmTestbedRunner(wp.float16, device).run()
+    gemm_test_bed_runner(wp.float16, device).run()
+    gemm_test_bed_runner_transpose(wp.float16, device).run()
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_f32(test, device):
-    GemmTestbedRunner(wp.float32, device).run()
+    gemm_test_bed_runner(wp.float32, device).run()
+    gemm_test_bed_runner_transpose(wp.float32, device).run()
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_f64(test, device):
-    GemmTestbedRunner(wp.float64, device).run()
+    gemm_test_bed_runner(wp.float64, device).run()
+    gemm_test_bed_runner_transpose(wp.float64, device).run()
 
 
 @wp.kernel
@@ -119,20 +277,22 @@ def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float))
     wp.atomic_add(loss, 0, arr[i, j])
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_tape(test, device):
+    rng = np.random.default_rng(42)
     low = -4.5
     high = 3.5
     m = 64
     n = 128
     k = 256
     A = wp.array2d(
-        np.ceil(np.random.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
     )
     B = wp.array2d(
-        np.ceil(np.random.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
     )
     C = wp.array2d(
-        np.ceil(np.random.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
    )
     D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
@@ -145,6 +305,7 @@ def test_tape(test, device):
 
     tape.backward(loss=loss)
     A_grad = A.grad.numpy()
+    tape.reset()
 
     # test adjoint
     D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
@@ -156,17 +317,19 @@ def test_tape(test, device):
     assert_array_equal(A.grad, wp.zeros_like(A))
 
 
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
 def test_operator(test, device):
+    rng = np.random.default_rng(42)
     low = -4.5
     high = 3.5
     m = 64
     n = 128
     k = 256
     A = wp.array2d(
-        np.ceil(np.random.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
     )
     B = wp.array2d(
-        np.ceil(np.random.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
+        np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
     )
     loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
 
@@ -180,7 +343,6 @@ def test_operator(test, device):
 
     # test adjoint
     D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
-    # deep copy, needed because transpose data is not contiguous
    B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
 
     adj_A = D.grad @ B_transpose
@@ -191,28 +353,101 @@
     assert_array_equal(A.grad, wp.zeros_like(A))
 
 
-def register(parent):
-    devices = [d for d in get_test_devices()]
+@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
+def test_large_batch_count(test, device):
+    rng = np.random.default_rng(42)
+    low = -4.5
+    high = 3.5
+    m = 2
+    n = 3
+    k = 4
+    batch_count = 65535 * 2 + int(65535 / 2)
+    A = wp.array3d(
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
+    )
+    B = wp.array3d(
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
+    )
+    C = wp.array3d(
+        np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
+        dtype=float,
+        device=device,
+        requires_grad=True,
+    )
+    D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
+    ones = wp.zeros_like(D)
+    ones.fill_(1.0)
+
+    alpha = 1.0
+    beta = 1.0
 
-    class TestMatmul(parent):
-        pass
+    tape = wp.Tape()
+    with tape:
+        wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False, device=device)
+    tape.backward(grads={D: ones})
 
-    if devices:
-        # check if CUTLASS is available
-        from warp.context import runtime
+    D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
+    assert np.array_equal(D_np, D.numpy())
+
+    adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
+    adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
+    adj_C_np = beta * ones.numpy()
+
+    assert np.array_equal(adj_A_np, A.grad.numpy())
+    assert np.array_equal(adj_B_np, B.grad.numpy())
+    assert np.array_equal(adj_C_np, C.grad.numpy())
+
+
+def test_adjoint_accumulation(test, device):
+    a_np = np.ones(shape=(2, 3))
+    b_np = np.ones(shape=(3, 2))
+    c_np = np.zeros(shape=(2, 2))
+    d_np = np.zeros(shape=(2, 2))
+
+    a_wp = wp.from_numpy(a_np, dtype=float, requires_grad=True)
+    b_wp = wp.from_numpy(b_np, dtype=float, requires_grad=True)
+    c_wp = wp.from_numpy(c_np, dtype=float, requires_grad=True)
+    d1_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True)
+    d2_wp = wp.from_numpy(d_np, dtype=float, requires_grad=True)
+
+    tape = wp.Tape()
+
+    with tape:
+        wp.matmul(a_wp, b_wp, c_wp, d1_wp, alpha=1.0, beta=1.0)
+        wp.matmul(a_wp, b_wp, d1_wp, d2_wp, alpha=1.0, beta=1.0)
+
+    d_grad = wp.zeros_like(d2_wp)
+    d_grad.fill_(1.0)
+    grads = {d2_wp: d_grad}
+    tape.backward(grads=grads)
+
+    assert np.array_equal(a_wp.grad.numpy(), 4.0 * np.ones(shape=(2, 3)))
+    assert np.array_equal(b_wp.grad.numpy(), 4.0 * np.ones(shape=(3, 2)))
+    assert np.array_equal(c_wp.grad.numpy(), np.ones(shape=(2, 2)))
+
+
+devices = get_test_devices()
+
+
+class TestMatmul(unittest.TestCase):
+    pass
 
-        if runtime.core.is_cutlass_enabled():
-            # add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
-            add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
-            add_function_test(TestMatmul, "test_f64", test_f64, devices=devices)
-            add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
-            add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
-        else:
-            print("Skipping matmul tests because CUTLASS is not supported in this build")
 
-    return TestMatmul
+# add_function_test(TestMatmul, "test_f16", test_f16, devices=devices)
+add_function_test(TestMatmul, "test_f32", test_f32, devices=devices)
add_function_test(TestMatmul, "test_f64", test_f64, devices=devices)
+add_function_test(TestMatmul, "test_tape", test_tape, devices=devices)
+add_function_test(TestMatmul, "test_operator", test_operator, devices=devices)
+add_function_test(TestMatmul, "test_large_batch_count", test_large_batch_count, devices=devices)
+add_function_test(TestMatmul, "test_adjoint_accumulation", test_adjoint_accumulation, devices=devices)
 
 
 if __name__ == "__main__":
-    c = register(unittest.TestCase)
+    wp.build.clear_kernel_cache()
     unittest.main(verbosity=2, failfast=False)