warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269) hide show
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,410 @@
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
+ wp.init()
16
+
17
+ from warp.context import runtime # noqa: E402
18
+
19
+
20
+ class gemm_test_bed_runner:
21
+ def __init__(self, dtype, device):
22
+ self.dtype = dtype
23
+ self.device = device
24
+
25
+ def alloc(self, m, n, k, batch_count):
26
+ rng = np.random.default_rng(42)
27
+ low = -4.5
28
+ high = 3.5
29
+ if batch_count == 1:
30
+ A = wp.array2d(
31
+ np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
32
+ dtype=self.dtype,
33
+ device=self.device,
34
+ requires_grad=True,
35
+ )
36
+ B = wp.array2d(
37
+ np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
38
+ dtype=self.dtype,
39
+ device=self.device,
40
+ requires_grad=True,
41
+ )
42
+ C = wp.array2d(
43
+ np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
44
+ dtype=self.dtype,
45
+ device=self.device,
46
+ requires_grad=True,
47
+ )
48
+ D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
49
+ else:
50
+ A = wp.array3d(
51
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
52
+ dtype=self.dtype,
53
+ device=self.device,
54
+ requires_grad=True,
55
+ )
56
+ B = wp.array3d(
57
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
58
+ dtype=self.dtype,
59
+ device=self.device,
60
+ requires_grad=True,
61
+ )
62
+ C = wp.array3d(
63
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
64
+ dtype=self.dtype,
65
+ device=self.device,
66
+ requires_grad=True,
67
+ )
68
+ D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
69
+ return A, B, C, D
70
+
71
+ def run_and_verify(self, m, n, k, batch_count, alpha, beta):
72
+ A, B, C, D = self.alloc(m, n, k, batch_count)
73
+ ones = wp.zeros_like(D)
74
+ ones.fill_(1.0)
75
+
76
+ if batch_count == 1:
77
+ tape = wp.Tape()
78
+ with tape:
79
+ wp.matmul(A, B, C, D, alpha, beta, False, self.device)
80
+ tape.backward(grads={D: ones})
81
+
82
+ D_np = alpha * (A.numpy() @ B.numpy()) + beta * C.numpy()
83
+ assert np.array_equal(D_np, D.numpy())
84
+
85
+ adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose())
86
+ adj_B_np = alpha * (A.numpy().transpose() @ ones.numpy())
87
+ adj_C_np = beta * ones.numpy()
88
+
89
+ else:
90
+ tape = wp.Tape()
91
+ with tape:
92
+ wp.batched_matmul(A, B, C, D, alpha, beta, False, self.device)
93
+ tape.backward(grads={D: ones})
94
+
95
+ D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
96
+ assert np.array_equal(D_np, D.numpy())
97
+
98
+ adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
99
+ adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
100
+ adj_C_np = beta * ones.numpy()
101
+
102
+ assert np.array_equal(adj_A_np, A.grad.numpy())
103
+ assert np.array_equal(adj_B_np, B.grad.numpy())
104
+ assert np.array_equal(adj_C_np, C.grad.numpy())
105
+
106
+ def run(self):
107
+ Ms = [8]
108
+ Ns = [16]
109
+ Ks = [32]
110
+ batch_counts = [1]
111
+ betas = [1.0]
112
+ alpha = 1.0
113
+
114
+ for batch_count in batch_counts:
115
+ for m in Ms:
116
+ for n in Ns:
117
+ for k in Ks:
118
+ for beta in betas:
119
+ self.run_and_verify(m, n, k, batch_count, alpha, beta)
120
+
121
+
122
+ class gemm_test_bed_runner_transpose:
123
+ def __init__(self, dtype, device):
124
+ self.dtype = dtype
125
+ self.device = device
126
+
127
+ def alloc(self, m, n, k, batch_count):
128
+ rng = np.random.default_rng(42)
129
+ low = -4.5
130
+ high = 3.5
131
+ if batch_count == 1:
132
+ A = wp.array2d(
133
+ np.ceil(rng.uniform(low=low, high=high, size=(m, k))),
134
+ dtype=self.dtype,
135
+ device=self.device,
136
+ requires_grad=True,
137
+ )
138
+ B = wp.array2d(
139
+ np.ceil(rng.uniform(low=low, high=high, size=(k, n))),
140
+ dtype=self.dtype,
141
+ device=self.device,
142
+ requires_grad=True,
143
+ )
144
+ C = wp.array2d(
145
+ np.ceil(rng.uniform(low=low, high=high, size=(m, n))),
146
+ dtype=self.dtype,
147
+ device=self.device,
148
+ requires_grad=True,
149
+ )
150
+ D = wp.array2d(np.zeros((m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
151
+ AT = wp.array2d(A.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
152
+ BT = wp.array2d(B.numpy().transpose([1, 0]), dtype=self.dtype, device=self.device, requires_grad=True)
153
+ else:
154
+ A = wp.array3d(
155
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
156
+ dtype=self.dtype,
157
+ device=self.device,
158
+ requires_grad=True,
159
+ )
160
+ B = wp.array3d(
161
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
162
+ dtype=self.dtype,
163
+ device=self.device,
164
+ requires_grad=True,
165
+ )
166
+ C = wp.array3d(
167
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
168
+ dtype=self.dtype,
169
+ device=self.device,
170
+ requires_grad=True,
171
+ )
172
+ D = wp.array3d(np.zeros((batch_count, m, n)), dtype=self.dtype, device=self.device, requires_grad=True)
173
+ AT = wp.array3d(A.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
174
+ BT = wp.array3d(B.numpy().transpose([0, 2, 1]), dtype=self.dtype, device=self.device, requires_grad=True)
175
+ return A, B, C, D, AT, BT
176
+
177
+ def run_and_verify(self, m, n, k, batch_count, alpha, beta):
178
+ A, B, C1, D1, AT1, BT1 = self.alloc(m, n, k, batch_count)
179
+ C2 = wp.clone(C1)
180
+ C3 = wp.clone(C1)
181
+ D2 = wp.clone(D1)
182
+ D3 = wp.clone(D1)
183
+ AT2 = wp.clone(AT1)
184
+ BT2 = wp.clone(BT1)
185
+ ones1 = wp.zeros_like(D1)
186
+ ones1.fill_(1.0)
187
+ ones2 = wp.zeros_like(D2)
188
+ ones2.fill_(1.0)
189
+ ones3 = wp.zeros_like(D3)
190
+ ones3.fill_(1.0)
191
+
192
+ if batch_count == 1:
193
+ ATT1 = AT1.transpose([1, 0])
194
+ BTT1 = BT1.transpose([1, 0])
195
+ ATT2 = AT2.transpose([1, 0])
196
+ BTT2 = BT2.transpose([1, 0])
197
+ tape = wp.Tape()
198
+ with tape:
199
+ wp.matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
200
+ wp.matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
201
+ wp.matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
202
+ tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
203
+
204
+ D_np = alpha * (A.numpy() @ B.numpy()) + beta * C1.numpy()
205
+ assert np.array_equal(D_np, D1.numpy())
206
+ assert np.array_equal(D_np, D2.numpy())
207
+ assert np.array_equal(D_np, D3.numpy())
208
+
209
+ adj_A_np = alpha * (ones1.numpy() @ B.numpy().transpose())
210
+ adj_B_np = alpha * (A.numpy().transpose() @ ones1.numpy())
211
+ adj_C_np = beta * ones1.numpy()
212
+
213
+ else:
214
+ ATT1 = AT1.transpose([0, 2, 1])
215
+ BTT1 = BT1.transpose([0, 2, 1])
216
+ ATT2 = AT2.transpose([0, 2, 1])
217
+ BTT2 = BT2.transpose([0, 2, 1])
218
+ tape = wp.Tape()
219
+ with tape:
220
+ wp.batched_matmul(A, BTT1, C1, D1, alpha, beta, False, self.device)
221
+ wp.batched_matmul(ATT1, B, C2, D2, alpha, beta, False, self.device)
222
+ wp.batched_matmul(ATT2, BTT2, C3, D3, alpha, beta, False, self.device)
223
+ tape.backward(grads={D1: ones1, D2: ones2, D3: ones3})
224
+
225
+ D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C1.numpy()
226
+ assert np.array_equal(D_np, D1.numpy())
227
+ assert np.array_equal(D_np, D2.numpy())
228
+ assert np.array_equal(D_np, D3.numpy())
229
+
230
+ adj_A_np = alpha * np.matmul(ones1.numpy(), B.numpy().transpose((0, 2, 1)))
231
+ adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones1.numpy())
232
+ adj_C_np = beta * ones1.numpy()
233
+
234
+ assert np.array_equal(adj_A_np, A.grad.numpy())
235
+ assert np.array_equal(adj_A_np, ATT1.grad.numpy())
236
+ assert np.array_equal(adj_A_np, ATT2.grad.numpy())
237
+ assert np.array_equal(adj_B_np, B.grad.numpy())
238
+ assert np.array_equal(adj_B_np, BTT1.grad.numpy())
239
+ assert np.array_equal(adj_B_np, BTT2.grad.numpy())
240
+ assert np.array_equal(adj_C_np, C1.grad.numpy())
241
+ assert np.array_equal(adj_C_np, C2.grad.numpy())
242
+ assert np.array_equal(adj_C_np, C3.grad.numpy())
243
+
244
+ def run(self):
245
+ m = 8
246
+ n = 16
247
+ k = 32
248
+ batch_counts = [1, 4]
249
+ beta = 1.0
250
+ alpha = 1.0
251
+
252
+ for batch_count in batch_counts:
253
+ self.run_and_verify(m, n, k, batch_count, alpha, beta)
254
+
255
+
256
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
257
+ def test_f32(test, device):
258
+ gemm_test_bed_runner(wp.float32, device).run()
259
+ gemm_test_bed_runner_transpose(wp.float32, device).run()
260
+
261
+
262
+ @wp.kernel
263
+ def matrix_sum_kernel(arr: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
264
+ i, j = wp.tid()
265
+ wp.atomic_add(loss, 0, arr[i, j])
266
+
267
+
268
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
269
+ def test_tape(test, device):
270
+ rng = np.random.default_rng(42)
271
+ low = -4.5
272
+ high = 3.5
273
+ m = 8
274
+ n = 16
275
+ k = 32
276
+ A = wp.array2d(
277
+ np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
278
+ )
279
+ B = wp.array2d(
280
+ np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
281
+ )
282
+ C = wp.array2d(
283
+ np.ceil(rng.uniform(low=low, high=high, size=(m, n))), dtype=float, device=device, requires_grad=True
284
+ )
285
+ D = wp.array2d(np.zeros((m, n)), dtype=float, device=device, requires_grad=True)
286
+ loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
287
+
288
+ # test tape
289
+ tape = wp.Tape()
290
+ with tape:
291
+ wp.matmul(A, B, C, D, device=device)
292
+ wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
293
+
294
+ tape.backward(loss=loss)
295
+ A_grad = A.grad.numpy()
296
+ tape.reset()
297
+
298
+ # test adjoint
299
+ D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
300
+ wp.adj_matmul(A, B, C, A.grad, B.grad, C.grad, D.grad, device=device)
301
+ assert_np_equal(A_grad, A.grad.numpy())
302
+
303
+ # test zero
304
+ tape.zero()
305
+ assert_array_equal(A.grad, wp.zeros_like(A))
306
+
307
+
308
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
309
+ def test_operator(test, device):
310
+ rng = np.random.default_rng(42)
311
+ low = -4.5
312
+ high = 3.5
313
+ m = 8
314
+ n = 16
315
+ k = 32
316
+ A = wp.array2d(
317
+ np.ceil(rng.uniform(low=low, high=high, size=(m, k))), dtype=float, device=device, requires_grad=True
318
+ )
319
+ B = wp.array2d(
320
+ np.ceil(rng.uniform(low=low, high=high, size=(k, n))), dtype=float, device=device, requires_grad=True
321
+ )
322
+ loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
323
+
324
+ # test tape
325
+ tape = wp.Tape()
326
+ with tape:
327
+ D = A @ B
328
+ wp.launch(matrix_sum_kernel, dim=(m, n), inputs=[D, loss], device=device)
329
+
330
+ tape.backward(loss=loss)
331
+
332
+ # test adjoint
333
+ D.grad = wp.array2d(np.ones((m, n)), dtype=float, device=device)
334
+ B_transpose = wp.array2d(B.transpose().numpy(), dtype=float, device=device)
335
+
336
+ adj_A = D.grad @ B_transpose
337
+ assert_array_equal(adj_A, A.grad)
338
+
339
+ # test zero
340
+ tape.zero()
341
+ assert_array_equal(A.grad, wp.zeros_like(A))
342
+
343
+
344
+ @unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
345
+ def test_large_batch_count(test, device):
346
+ rng = np.random.default_rng(42)
347
+ low = -4.5
348
+ high = 3.5
349
+ m = 2
350
+ n = 3
351
+ k = 4
352
+ batch_count = 65535 * 2 + int(65535 / 2)
353
+ A = wp.array3d(
354
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, k))),
355
+ dtype=float,
356
+ device=device,
357
+ requires_grad=True,
358
+ )
359
+ B = wp.array3d(
360
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, k, n))),
361
+ dtype=float,
362
+ device=device,
363
+ requires_grad=True,
364
+ )
365
+ C = wp.array3d(
366
+ np.ceil(rng.uniform(low=low, high=high, size=(batch_count, m, n))),
367
+ dtype=float,
368
+ device=device,
369
+ requires_grad=True,
370
+ )
371
+ D = wp.array3d(np.zeros((batch_count, m, n)), dtype=float, device=device, requires_grad=True)
372
+ ones = wp.zeros_like(D)
373
+ ones.fill_(1.0)
374
+
375
+ alpha = 1.0
376
+ beta = 1.0
377
+
378
+ tape = wp.Tape()
379
+ with tape:
380
+ wp.batched_matmul(A, B, C, D, alpha=alpha, beta=beta, allow_tf32x3_arith=False, device=device)
381
+ tape.backward(grads={D: ones})
382
+
383
+ D_np = alpha * np.matmul(A.numpy(), B.numpy()) + beta * C.numpy()
384
+ assert np.array_equal(D_np, D.numpy())
385
+
386
+ adj_A_np = alpha * np.matmul(ones.numpy(), B.numpy().transpose((0, 2, 1)))
387
+ adj_B_np = alpha * np.matmul(A.numpy().transpose((0, 2, 1)), ones.numpy())
388
+ adj_C_np = beta * ones.numpy()
389
+
390
+ assert np.array_equal(adj_A_np, A.grad.numpy())
391
+ assert np.array_equal(adj_B_np, B.grad.numpy())
392
+ assert np.array_equal(adj_C_np, C.grad.numpy())
393
+
394
+
395
+ devices = get_test_devices()
396
+
397
+
398
+ class TestMatmulLite(unittest.TestCase):
399
+ pass
400
+
401
+
402
+ add_function_test(TestMatmulLite, "test_f32", test_f32, devices=devices)
403
+ add_function_test(TestMatmulLite, "test_tape", test_tape, devices=devices)
404
+ add_function_test(TestMatmulLite, "test_operator", test_operator, devices=devices)
405
+ add_function_test(TestMatmulLite, "test_large_batch_count", test_large_batch_count, devices=devices)
406
+
407
+
408
+ if __name__ == "__main__":
409
+ wp.build.clear_kernel_cache()
410
+ unittest.main(verbosity=2, failfast=False)
warp/tests/test_mesh.py CHANGED
@@ -10,8 +10,7 @@ import unittest
10
10
  import numpy as np
11
11
 
12
12
  import warp as wp
13
- from warp.tests.test_base import *
14
-
13
+ from warp.tests.unittest_utils import *
15
14
 
16
15
  # fmt: off
17
16
 
@@ -223,9 +222,9 @@ def query_ray_kernel(
223
222
 
224
223
 
225
224
  def test_mesh_query_ray(test, device):
226
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
225
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
227
226
 
228
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
227
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
229
228
  mesh = wp.Mesh(points=points, indices=indices)
230
229
  expected_sign = -1.0
231
230
  wp.launch(
@@ -235,9 +234,10 @@ def test_mesh_query_ray(test, device):
235
234
  mesh.id,
236
235
  expected_sign,
237
236
  ],
237
+ device=device,
238
238
  )
239
239
 
240
- indices = wp.array(LEFT_HANDED_FACE_VERTEX_INDICES, dtype=int)
240
+ indices = wp.array(LEFT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
241
241
  mesh = wp.Mesh(points=points, indices=indices)
242
242
  expected_sign = 1.0
243
243
  wp.launch(
@@ -247,40 +247,78 @@ def test_mesh_query_ray(test, device):
247
247
  mesh.id,
248
248
  expected_sign,
249
249
  ],
250
+ device=device,
250
251
  )
251
252
 
252
253
 
253
254
  def test_mesh_refit_graph(test, device):
254
- points = wp.array(POINT_POSITIONS, dtype=wp.vec3)
255
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
255
256
 
256
- indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int)
257
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
257
258
  mesh = wp.Mesh(points=points, indices=indices)
258
259
 
259
- wp.capture_begin()
260
-
261
- mesh.refit()
262
-
263
- graph = wp.capture_end()
260
+ wp.capture_begin(device, force_module_load=False)
261
+ try:
262
+ mesh.refit()
263
+ finally:
264
+ graph = wp.capture_end(device)
264
265
 
265
266
  # replay
266
267
  num_iters = 10
267
268
  for _ in range(num_iters):
268
269
  wp.capture_launch(graph)
269
270
 
271
+ wp.synchronize_device(device)
272
+
273
+
274
+ def test_mesh_exceptions(test, device):
275
+ # points and indices must be on same device
276
+ with test.assertRaises(RuntimeError):
277
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device="cpu")
278
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
279
+ wp.Mesh(points=points, indices=indices)
280
+
281
+ # points must be vec3
282
+ with test.assertRaises(RuntimeError):
283
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3d, device=device)
284
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
285
+ wp.Mesh(points=points, indices=indices)
286
+
287
+ # velocities must be vec3
288
+ with test.assertRaises(RuntimeError):
289
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
290
+ velocities = wp.zeros(points.shape, dtype=wp.vec3d, device=device)
291
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
292
+ wp.Mesh(points=points, indices=indices, velocities=velocities)
293
+
294
+ # indices must be int32
295
+ with test.assertRaises(RuntimeError):
296
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
297
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=wp.int64, device=device)
298
+ wp.Mesh(points=points, indices=indices)
299
+
300
+ # indices must be 1d
301
+ with test.assertRaises(RuntimeError):
302
+ points = wp.array(POINT_POSITIONS, dtype=wp.vec3, device=device)
303
+ indices = wp.array(RIGHT_HANDED_FACE_VERTEX_INDICES, dtype=int, device=device)
304
+ indices = indices.reshape((3, -1))
305
+ wp.Mesh(points=points, indices=indices)
306
+
307
+
308
+ devices = get_test_devices()
309
+
270
310
 
271
- def register(parent):
272
- devices = get_test_devices()
311
+ class TestMesh(unittest.TestCase):
312
+ pass
273
313
 
274
- class TestMesh(parent):
275
- pass
276
314
 
277
- add_function_test(TestMesh, "test_mesh_read_properties", test_mesh_read_properties, devices=devices)
278
- add_function_test(TestMesh, "test_mesh_query_point", test_mesh_query_point, devices=devices)
279
- add_function_test(TestMesh, "test_mesh_query_ray", test_mesh_query_ray, devices=devices)
280
- add_function_test(TestMesh, "test_mesh_refit_graph", test_mesh_refit_graph, devices=wp.get_cuda_devices())
281
- return TestMesh
315
+ add_function_test(TestMesh, "test_mesh_read_properties", test_mesh_read_properties, devices=devices)
316
+ add_function_test(TestMesh, "test_mesh_query_point", test_mesh_query_point, devices=devices)
317
+ add_function_test(TestMesh, "test_mesh_query_ray", test_mesh_query_ray, devices=devices)
318
+ add_function_test(TestMesh, "test_mesh_refit_graph", test_mesh_refit_graph, devices=get_unique_cuda_test_devices())
319
+ add_function_test(TestMesh, "test_mesh_exceptions", test_mesh_exceptions, devices=get_unique_cuda_test_devices())
282
320
 
283
321
 
284
322
  if __name__ == "__main__":
285
- _ = register(unittest.TestCase)
323
+ wp.build.clear_kernel_cache()
286
324
  unittest.main(verbosity=2)
@@ -5,10 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
9
11
 
10
12
  import warp as wp
11
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
12
14
 
13
15
  wp.init()
14
16
 
@@ -96,7 +98,6 @@ def test_compute_bounds(test, device):
96
98
 
97
99
  lower_view = lowers.numpy()
98
100
  upper_view = uppers.numpy()
99
- wp.synchronize()
100
101
 
101
102
  # Confirm the bounds of each triangle are correct.
102
103
  test.assertTrue(lower_view[0][0] == 0)
@@ -148,8 +149,6 @@ def test_mesh_query_aabb_count_overlap(test, device):
148
149
  device=device,
149
150
  )
150
151
 
151
- wp.synchronize()
152
-
153
152
  view = counts.numpy()
154
153
 
155
154
  # 2 triangles that share a vertex having overlapping AABBs.
@@ -188,8 +187,6 @@ def test_mesh_query_aabb_count_nonoverlap(test, device):
188
187
  device=device,
189
188
  )
190
189
 
191
- wp.synchronize()
192
-
193
190
  view = counts.numpy()
194
191
 
195
192
  # AABB query only returns one triangle at a time, the triangles are not close enough to overlap.
@@ -197,29 +194,28 @@ def test_mesh_query_aabb_count_nonoverlap(test, device):
197
194
  test.assertTrue(c == 1)
198
195
 
199
196
 
200
- def register(parent):
201
- devices = get_test_devices()
197
+ devices = get_test_devices()
202
198
 
203
- class TestMeshQueryAABBMethods(parent):
204
- pass
205
199
 
206
- add_function_test(TestMeshQueryAABBMethods, "test_compute_bounds", test_compute_bounds, devices=devices)
207
- add_function_test(
208
- TestMeshQueryAABBMethods,
209
- "test_mesh_query_aabb_count_overlap",
210
- test_mesh_query_aabb_count_overlap,
211
- devices=devices,
212
- )
213
- add_function_test(
214
- TestMeshQueryAABBMethods,
215
- "test_mesh_query_aabb_count_nonoverlap",
216
- test_mesh_query_aabb_count_nonoverlap,
217
- devices=devices,
218
- )
200
+ class TestMeshQueryAABBMethods(unittest.TestCase):
201
+ pass
202
+
219
203
 
220
- return TestMeshQueryAABBMethods
204
+ add_function_test(TestMeshQueryAABBMethods, "test_compute_bounds", test_compute_bounds, devices=devices)
205
+ add_function_test(
206
+ TestMeshQueryAABBMethods,
207
+ "test_mesh_query_aabb_count_overlap",
208
+ test_mesh_query_aabb_count_overlap,
209
+ devices=devices,
210
+ )
211
+ add_function_test(
212
+ TestMeshQueryAABBMethods,
213
+ "test_mesh_query_aabb_count_nonoverlap",
214
+ test_mesh_query_aabb_count_nonoverlap,
215
+ devices=devices,
216
+ )
221
217
 
222
218
 
223
219
  if __name__ == "__main__":
224
- c = register(unittest.TestCase)
220
+ wp.build.clear_kernel_cache()
225
221
  unittest.main(verbosity=2)