warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
1
+ import unittest
2
+
3
+ import numpy as np
4
+
5
+ import warp as wp
6
+ from warp.tests.unittest_utils import *
7
+
8
+ wp.init()
9
+
10
+
11
+ def test_basic(test, device):
12
+ snippet = """
13
+ out[tid] = a * x[tid] + y[tid];
14
+ """
15
+ adj_snippet = """
16
+ adj_a = x[tid] * adj_out[tid];
17
+ adj_x[tid] = a * adj_out[tid];
18
+ adj_y[tid] = adj_out[tid];
19
+ """
20
+
21
+ @wp.func_native(snippet, adj_snippet)
22
+ def saxpy(
23
+ a: wp.float32,
24
+ x: wp.array(dtype=wp.float32),
25
+ y: wp.array(dtype=wp.float32),
26
+ out: wp.array(dtype=wp.float32),
27
+ tid: int,
28
+ ):
29
+ ...
30
+
31
+ @wp.kernel
32
+ def saxpy_cu(
33
+ a: wp.float32, x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32), out: wp.array(dtype=wp.float32)
34
+ ):
35
+ tid = wp.tid()
36
+ saxpy(a, x, y, out, tid)
37
+
38
+ @wp.kernel
39
+ def saxpy_py(
40
+ a: wp.float32, x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32), out: wp.array(dtype=wp.float32)
41
+ ):
42
+ tid = wp.tid()
43
+ out[tid] = a * x[tid] + y[tid]
44
+
45
+ N = 128
46
+
47
+ a1 = 2.0
48
+ x1 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device, requires_grad=True)
49
+ y1 = wp.zeros_like(x1)
50
+ out1 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device)
51
+ adj_out1 = wp.array(np.ones(N, dtype=np.float32), dtype=wp.float32, device=device)
52
+
53
+ a2 = 2.0
54
+ x2 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device, requires_grad=True)
55
+ y2 = wp.zeros_like(x2)
56
+ out2 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device)
57
+ adj_out2 = wp.array(np.ones(N, dtype=np.float32), dtype=wp.float32, device=device)
58
+
59
+ tape = wp.Tape()
60
+
61
+ with tape:
62
+ wp.launch(kernel=saxpy_cu, dim=N, inputs=[a1, x1, y1], outputs=[out1], device=device)
63
+ wp.launch(kernel=saxpy_py, dim=N, inputs=[a2, x2, y2], outputs=[out2], device=device)
64
+
65
+ tape.backward(grads={out1: adj_out1, out2: adj_out2})
66
+
67
+ # test forward snippet
68
+ assert_np_equal(out1.numpy(), out2.numpy())
69
+
70
+ # test backward snippet
71
+ assert_np_equal(x1.grad.numpy(), a1 * np.ones(N, dtype=np.float32))
72
+ assert_np_equal(x1.grad.numpy(), x2.grad.numpy())
73
+
74
+ assert_np_equal(y1.grad.numpy(), np.ones(N, dtype=np.float32))
75
+ assert_np_equal(y1.grad.numpy(), y2.grad.numpy())
76
+
77
+
78
+ def test_shared_memory(test, device):
79
+ snippet = """
80
+ __shared__ int s[128];
81
+
82
+ s[tid] = d[tid];
83
+ __syncthreads();
84
+ d[tid] = s[N - tid - 1];
85
+ """
86
+
87
+ @wp.func_native(snippet)
88
+ def reverse(d: wp.array(dtype=int), N: int, tid: int):
89
+ return
90
+
91
+ @wp.kernel
92
+ def reverse_kernel(d: wp.array(dtype=int), N: int):
93
+ tid = wp.tid()
94
+ reverse(d, N, tid)
95
+
96
+ N = 128
97
+ x = wp.array(np.arange(N, dtype=int), dtype=int, device=device)
98
+ y = np.arange(127, -1, -1, dtype=int)
99
+
100
+ wp.launch(kernel=reverse_kernel, dim=N, inputs=[x, N], device=device)
101
+
102
+ assert_np_equal(x.numpy(), y)
103
+
104
+
105
+ def test_cpu_snippet(test, device):
106
+ snippet = """
107
+ int inc = 1;
108
+ out[tid] = x[tid] + inc;
109
+ """
110
+
111
+ @wp.func_native(snippet)
112
+ def increment_snippet(
113
+ x: wp.array(dtype=wp.int32),
114
+ out: wp.array(dtype=wp.int32),
115
+ tid: int,
116
+ ):
117
+ ...
118
+
119
+ @wp.kernel
120
+ def increment(x: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.int32)):
121
+ tid = wp.tid()
122
+ increment_snippet(x, out, tid)
123
+
124
+ N = 128
125
+ x = wp.array(np.arange(N, dtype=np.int32), dtype=wp.int32, device=device)
126
+ out = wp.zeros(N, dtype=wp.int32, device=device)
127
+
128
+ wp.launch(kernel=increment, dim=N, inputs=[x], outputs=[out], device=device)
129
+
130
+ assert_np_equal(out.numpy(), np.arange(1, N + 1, 1, dtype=np.int32))
131
+
132
+
133
+ class TestSnippets(unittest.TestCase):
134
+ pass
135
+
136
+
137
+ add_function_test(TestSnippets, "test_basic", test_basic, devices=get_unique_cuda_test_devices())
138
+ add_function_test(TestSnippets, "test_shared_memory", test_shared_memory, devices=get_unique_cuda_test_devices())
139
+ add_function_test(TestSnippets, "test_cpu_snippet", test_cpu_snippet, devices=["cpu"])
140
+
141
+
142
+ if __name__ == "__main__":
143
+ unittest.main(verbosity=2)
warp/tests/test_sparse.py CHANGED
@@ -1,8 +1,20 @@
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
1
10
  import numpy as np
11
+
2
12
  import warp as wp
13
+ from warp.sparse import bsr_zeros, bsr_set_from_triplets, bsr_get_diag, bsr_diag, bsr_identity, bsr_copy, bsr_scale
14
+ from warp.sparse import bsr_set_transpose, bsr_transposed
15
+ from warp.sparse import bsr_axpy, bsr_mm, bsr_axpy_work_arrays, bsr_mm_work_arrays, bsr_mv
16
+ from warp.tests.unittest_utils import *
3
17
 
4
- from warp.sparse import bsr_zeros, bsr_set_from_triplets, bsr_get_diag, bsr_diag, bsr_set_transpose, bsr_axpy, bsr_mm
5
- from warp.tests.test_base import *
6
18
 
7
19
  wp.init()
8
20
 
@@ -46,45 +58,62 @@ def _bsr_to_dense(bsr):
46
58
 
47
59
 
48
60
  def test_csr_from_triplets(test, device):
61
+ rng = np.random.default_rng(123)
62
+
49
63
  shape = (8, 6)
50
64
  n = 100
51
65
 
52
- rows = wp.array(np.random.randint(0, shape[0], n, dtype=int), dtype=int, device=device)
53
- cols = wp.array(np.random.randint(0, shape[1], n, dtype=int), dtype=int, device=device)
54
- vals = wp.array(np.random.rand(n), dtype=float, device=device)
66
+ rows = wp.array(rng.integers(0, high=shape[0], size=n, dtype=int), dtype=int, device=device)
67
+ cols = wp.array(rng.integers(0, high=shape[1], size=n, dtype=int), dtype=int, device=device)
68
+ vals = wp.array(rng.random(size=n), dtype=float, device=device)
55
69
 
56
70
  ref = _triplets_to_dense(shape, rows, cols, vals)
57
71
 
58
72
  csr = bsr_zeros(shape[0], shape[1], float, device=device)
59
73
  bsr_set_from_triplets(csr, rows, cols, vals)
74
+ test.assertEqual(csr.block_size, 1)
60
75
 
61
76
  res = _bsr_to_dense(csr)
62
77
 
63
- assert_np_equal(ref, res, 0.0001)
78
+ assert_np_equal(res, ref, 0.0001)
64
79
 
65
80
 
66
81
  def test_bsr_from_triplets(test, device):
82
+ rng = np.random.default_rng(123)
83
+
67
84
  block_shape = (3, 2)
68
85
  nrow = 4
69
86
  ncol = 9
70
87
  shape = (block_shape[0] * nrow, block_shape[1] * ncol)
71
88
  n = 50
72
89
 
73
- rows = wp.array(np.random.randint(0, nrow, n, dtype=int), dtype=int, device=device)
74
- cols = wp.array(np.random.randint(0, ncol, n, dtype=int), dtype=int, device=device)
75
- vals = wp.array(np.random.rand(n, block_shape[0], block_shape[1]), dtype=float, device=device)
90
+ rows = wp.array(rng.integers(0, high=nrow, size=n, dtype=int), dtype=int, device=device)
91
+ cols = wp.array(rng.integers(0, high=ncol, size=n, dtype=int), dtype=int, device=device)
92
+ vals = wp.array(rng.random(size=(n, block_shape[0], block_shape[1])), dtype=float, device=device)
76
93
 
77
94
  ref = _triplets_to_dense(shape, rows, cols, vals)
78
95
 
79
96
  bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
80
97
  bsr_set_from_triplets(bsr, rows, cols, vals)
98
+ test.assertEqual(bsr.block_size, block_shape[0] * block_shape[1])
81
99
 
82
100
  res = _bsr_to_dense(bsr)
83
101
 
84
- assert_np_equal(ref, res, 0.0001)
102
+ assert_np_equal(res, ref, 0.0001)
103
+
104
+ # test zero-length inputs
105
+ bsr_set_from_triplets(
106
+ bsr,
107
+ wp.array([], dtype=int, device=device),
108
+ wp.array([], dtype=int, device=device),
109
+ wp.array([], shape=(0, block_shape[0], block_shape[1]), dtype=float, device=device),
110
+ )
111
+ test.assertEqual(bsr.nnz, 0)
112
+
85
113
 
114
+ def test_bsr_get_set_diag(test, device):
115
+ rng = np.random.default_rng(123)
86
116
 
87
- def test_bsr_get_diag(test, device):
88
117
  block_shape = (3, 3)
89
118
  nrow = 4
90
119
  ncol = 4
@@ -92,7 +121,7 @@ def test_bsr_get_diag(test, device):
92
121
 
93
122
  rows = wp.array([0, 1, 2, 3, 2, 1], dtype=int, device=device)
94
123
  cols = wp.array([1, 1, 1, 3, 2, 2], dtype=int, device=device)
95
- vals_np = np.random.rand(nnz, block_shape[0], block_shape[1])
124
+ vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
96
125
  vals = wp.array(vals_np, dtype=float, device=device)
97
126
 
98
127
  bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
@@ -106,14 +135,46 @@ def test_bsr_get_diag(test, device):
106
135
  assert_np_equal(diag_np[2], vals_np[4], tol=0.00001)
107
136
  assert_np_equal(diag_np[3], vals_np[3], tol=0.00001)
108
137
 
109
- # Test round-trip
138
+ # Test set_diag/get_diag round-trips with various block types
139
+
140
+ # Array of blocks
110
141
  diag_bsr = bsr_diag(diag)
111
- diag = bsr_get_diag(diag_bsr)
142
+ bsr_get_diag(diag_bsr, out=diag)
112
143
  assert_np_equal(diag_np, diag.numpy())
113
144
 
145
+ diag_scalar_np = rng.random(size=nrow)
146
+ diag_scalar = wp.array(diag_scalar_np, device=device)
147
+ diag_bsr = bsr_diag(diag_scalar)
148
+ diag = bsr_get_diag(diag_bsr)
149
+ assert_np_equal(diag_scalar_np, diag.numpy(), tol=0.000001)
150
+
151
+ # Uniform block diagonal
152
+
153
+ with test.assertRaisesRegex(ValueError, "BsrMatrix block type must be either warp matrix or scalar"):
154
+ # 1d block type -- invalid
155
+ diag_bsr = bsr_diag(diag=vals_np[0, 0], rows_of_blocks=nrow, cols_of_blocks=nrow + 1)
156
+
157
+ diag_bsr = bsr_diag(diag=vals_np[0], rows_of_blocks=nrow, cols_of_blocks=nrow + 1)
158
+ assert diag_bsr.values.shape[0] == nrow
159
+ assert_np_equal(diag_bsr.values.numpy(), np.broadcast_to(vals_np[0], shape=(nrow, *block_shape)), tol=0.000001)
160
+
161
+ diag_bsr = bsr_diag(diag=float(diag_scalar_np[0]), rows_of_blocks=nrow, cols_of_blocks=nrow + 1)
162
+ assert diag_bsr.values.shape[0] == nrow
163
+ assert_np_equal(diag_bsr.values.numpy(), np.full(nrow, diag_scalar_np[0]), tol=0.000001)
164
+
165
+ # Identity matrix
166
+ diag_bsr = bsr_identity(nrow, block_type=wp.mat44, device=device)
167
+ assert diag_bsr.values.shape[0] == nrow
168
+ assert_np_equal(diag_bsr.values.numpy(), np.broadcast_to(np.eye(4), shape=(nrow, 4, 4)), tol=0.000001)
169
+
170
+ diag_csr = bsr_identity(nrow, block_type=wp.float64, device=device)
171
+ assert np.all(diag_csr.values.numpy() == np.ones(nrow, dtype=float))
172
+
114
173
 
115
174
  def make_test_bsr_transpose(block_shape, scalar_type):
116
175
  def test_bsr_transpose(test, device):
176
+ rng = np.random.default_rng(123)
177
+
117
178
  nrow = 4
118
179
  ncol = 5
119
180
  nnz = 6
@@ -121,7 +182,7 @@ def make_test_bsr_transpose(block_shape, scalar_type):
121
182
  rows = wp.array([0, 1, 2, 3, 2, 1], dtype=int, device=device)
122
183
  cols = wp.array([1, 4, 1, 3, 0, 2], dtype=int, device=device)
123
184
 
124
- vals_np = np.random.rand(nnz, block_shape[0], block_shape[1])
185
+ vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
125
186
  vals = wp.array(vals_np, dtype=scalar_type, device=device).reshape((nnz, block_shape[0], block_shape[1]))
126
187
 
127
188
  bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
@@ -134,49 +195,92 @@ def make_test_bsr_transpose(block_shape, scalar_type):
134
195
  bsr_set_transpose(dest=bsr_transposed, src=bsr)
135
196
 
136
197
  res = _bsr_to_dense(bsr_transposed)
198
+ assert_np_equal(res, ref, 0.0001)
137
199
 
138
- assert_np_equal(ref, res, 0.0001)
200
+ if block_shape[0] != block_shape[-1]:
201
+ # test incompatible block shape
202
+ with test.assertRaisesRegex(ValueError, "Destination block shape must be"):
203
+ bsr_set_transpose(dest=bsr, src=bsr)
139
204
 
140
205
  return test_bsr_transpose
141
206
 
142
207
 
208
+ def test_bsr_copy_scale(test, device):
209
+ nrow = 6
210
+ bsize = 2
211
+
212
+ diag_bsr = bsr_diag(diag=np.eye(bsize, dtype=float) * 2.0, rows_of_blocks=nrow)
213
+ diag_copy = bsr_copy(diag_bsr, scalar_type=wp.float64)
214
+
215
+ test.assertTrue(wp.types.types_equal(diag_copy.values.dtype, wp.mat(shape=(bsize, bsize), dtype=wp.float64)))
216
+ bsr_scale(x=diag_copy, alpha=0.5)
217
+
218
+ res = _bsr_to_dense(diag_copy)
219
+ ref = np.eye(nrow * bsize)
220
+ assert_np_equal(res, ref, 0.0001)
221
+
222
+ bsr_scale(x=diag_copy, alpha=0.0)
223
+ test.assertEqual(diag_copy.nrow, nrow)
224
+ test.assertEqual(diag_copy.ncol, nrow)
225
+ test.assertEqual(diag_copy.nnz, 0)
226
+
227
+
143
228
  def make_test_bsr_axpy(block_shape, scalar_type):
144
229
  def test_bsr_axpy(test, device):
230
+ rng = np.random.default_rng(123)
231
+
145
232
  nrow = 2
146
233
  ncol = 3
147
234
  nnz = 6
148
235
 
149
- alpha = -1.0
150
- beta = 2.0
236
+ alphas = [-1.0, 0.0, 1.0]
237
+ betas = [2.0, -1.0, 0.0]
151
238
 
152
- x_rows = wp.array(np.random.randint(0, nrow, nnz, dtype=int), dtype=int, device=device)
153
- x_cols = wp.array(np.random.randint(0, ncol, nnz, dtype=int), dtype=int, device=device)
154
- x_vals = wp.array(np.random.rand(nnz, block_shape[0], block_shape[1]), dtype=scalar_type, device=device)
239
+ x_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
240
+ x_cols = wp.array(rng.integers(0, high=ncol, size=nnz, dtype=int), dtype=int, device=device)
241
+ x_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
155
242
  x_vals = x_vals.reshape((nnz, block_shape[0], block_shape[1]))
156
243
 
157
244
  x = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
158
245
  bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
159
246
 
160
- y_rows = wp.array(np.random.randint(0, nrow, nnz, dtype=int), dtype=int, device=device)
161
- y_cols = wp.array(np.random.randint(0, ncol, nnz, dtype=int), dtype=int, device=device)
162
- y_vals = wp.array(np.random.rand(nnz, block_shape[0], block_shape[1]), dtype=scalar_type, device=device)
247
+ y_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
248
+ y_cols = wp.array(rng.integers(0, high=ncol, size=nnz, dtype=int), dtype=int, device=device)
249
+ y_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
163
250
  y_vals = y_vals.reshape((nnz, block_shape[0], block_shape[1]))
164
251
 
165
252
  y = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
166
253
  bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
167
254
 
168
- ref = alpha * _bsr_to_dense(x) + beta * _bsr_to_dense(y)
255
+ work_arrays = bsr_axpy_work_arrays()
256
+ for alpha, beta in zip(alphas, betas):
257
+ ref = alpha * _bsr_to_dense(x) + beta * _bsr_to_dense(y)
258
+ if beta == 0.0:
259
+ y = bsr_axpy(x, alpha=alpha, beta=beta, work_arrays=work_arrays)
260
+ else:
261
+ bsr_axpy(x, y, alpha, beta, work_arrays=work_arrays)
169
262
 
170
- bsr_axpy(x, y, alpha, beta)
263
+ res = _bsr_to_dense(y)
264
+ assert_np_equal(res, ref, 0.0001)
171
265
 
266
+ # test aliasing
267
+ ref = 3.0 * _bsr_to_dense(y)
268
+ bsr_axpy(y, y, alpha=1.0, beta=2.0)
172
269
  res = _bsr_to_dense(y)
173
- assert_np_equal(ref, res, 0.0001)
270
+ assert_np_equal(res, ref, 0.0001)
271
+
272
+ # test incompatible shapes
273
+ y.ncol = y.ncol + 1
274
+ with test.assertRaisesRegex(ValueError, "Matrices must have the same number of rows and columns"):
275
+ bsr_axpy(x, y)
174
276
 
175
277
  return test_bsr_axpy
176
278
 
177
279
 
178
280
  def make_test_bsr_mm(block_shape, scalar_type):
179
281
  def test_bsr_mm(test, device):
282
+ rng = np.random.default_rng(123)
283
+
180
284
  x_nrow = 3
181
285
  x_ncol = 2
182
286
  x_block_shape = block_shape
@@ -191,72 +295,166 @@ def make_test_bsr_mm(block_shape, scalar_type):
191
295
 
192
296
  nnz = 6
193
297
 
194
- alpha = -1.0
195
- beta = 2.0
298
+ alphas = [-1.0, 0.0, 1.0]
299
+ betas = [2.0, -1.0, 0.0]
196
300
 
197
- x_rows = wp.array(np.random.randint(0, x_nrow, nnz, dtype=int), dtype=int, device=device)
198
- x_cols = wp.array(np.random.randint(0, x_ncol, nnz, dtype=int), dtype=int, device=device)
199
- x_vals = wp.array(np.random.rand(nnz, x_block_shape[0], x_block_shape[1]), dtype=scalar_type, device=device)
301
+ x_rows = wp.array(rng.integers(0, high=x_nrow, size=nnz, dtype=int), dtype=int, device=device)
302
+ x_cols = wp.array(rng.integers(0, high=x_ncol, size=nnz, dtype=int), dtype=int, device=device)
303
+ x_vals = wp.array(rng.random(size=(nnz, x_block_shape[0], x_block_shape[1])), dtype=scalar_type, device=device)
200
304
  x_vals = x_vals.reshape((nnz, x_block_shape[0], x_block_shape[1]))
201
305
 
202
306
  x = bsr_zeros(x_nrow, x_ncol, wp.types.matrix(shape=x_block_shape, dtype=scalar_type), device=device)
203
307
  bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
204
308
 
205
- y_rows = wp.array(np.random.randint(0, y_nrow, nnz, dtype=int), dtype=int, device=device)
206
- y_cols = wp.array(np.random.randint(0, y_ncol, nnz, dtype=int), dtype=int, device=device)
207
- y_vals = wp.array(np.random.rand(nnz, y_block_shape[0], y_block_shape[1]), dtype=scalar_type, device=device)
309
+ y_rows = wp.array(rng.integers(0, high=y_nrow, size=nnz, dtype=int), dtype=int, device=device)
310
+ y_cols = wp.array(rng.integers(0, high=y_ncol, size=nnz, dtype=int), dtype=int, device=device)
311
+ y_vals = wp.array(rng.random(size=(nnz, y_block_shape[0], y_block_shape[1])), dtype=scalar_type, device=device)
208
312
  y_vals = y_vals.reshape((nnz, y_block_shape[0], y_block_shape[1]))
209
313
 
210
314
  y = bsr_zeros(y_nrow, y_ncol, wp.types.matrix(shape=y_block_shape, dtype=scalar_type), device=device)
211
315
  bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
212
316
 
213
- z_rows = wp.array(np.random.randint(0, z_nrow, nnz, dtype=int), dtype=int, device=device)
214
- z_cols = wp.array(np.random.randint(0, z_ncol, nnz, dtype=int), dtype=int, device=device)
215
- z_vals = wp.array(np.random.rand(nnz, z_block_shape[0], z_block_shape[1]), dtype=scalar_type, device=device)
317
+ z_rows = wp.array(rng.integers(0, high=z_nrow, size=nnz, dtype=int), dtype=int, device=device)
318
+ z_cols = wp.array(rng.integers(0, high=z_ncol, size=nnz, dtype=int), dtype=int, device=device)
319
+ z_vals = wp.array(rng.random(size=(nnz, z_block_shape[0], z_block_shape[1])), dtype=scalar_type, device=device)
216
320
  z_vals = z_vals.reshape((nnz, z_block_shape[0], z_block_shape[1]))
217
321
 
218
322
  z = bsr_zeros(z_nrow, z_ncol, wp.types.matrix(shape=z_block_shape, dtype=scalar_type), device=device)
219
323
  bsr_set_from_triplets(z, z_rows, z_cols, z_vals)
220
324
 
221
- ref = alpha * (_bsr_to_dense(x) @ _bsr_to_dense(y)) + beta * _bsr_to_dense(z)
325
+ work_arrays = bsr_mm_work_arrays()
326
+ for alpha, beta in zip(alphas, betas):
327
+ ref = alpha * (_bsr_to_dense(x) @ _bsr_to_dense(y)) + beta * _bsr_to_dense(z)
328
+
329
+ bsr_mm(x, y, z, alpha, beta, work_arrays=work_arrays)
330
+
331
+ res = _bsr_to_dense(z)
332
+ assert_np_equal(res, ref, 0.0001)
333
+
334
+ # test aliasing of matrix arguments
335
+ # x = alpha * z * x + beta * x
336
+ alpha, beta = alphas[0], betas[0]
337
+ ref = alpha * (_bsr_to_dense(z) @ _bsr_to_dense(x)) + beta * _bsr_to_dense(x)
338
+ bsr_mm(z, x, x, alpha, beta)
222
339
 
223
- bsr_mm(x, y, z, alpha, beta)
340
+ res = _bsr_to_dense(x)
341
+ assert_np_equal(res, ref, 0.0001)
342
+
343
+ # z = alpha * z * z + beta * z
344
+ ref = alpha * (_bsr_to_dense(z) @ _bsr_to_dense(z)) + beta * _bsr_to_dense(z)
345
+ bsr_mm(z, z, z, alpha, beta)
224
346
 
225
347
  res = _bsr_to_dense(z)
226
- assert_np_equal(ref, res, 0.0001)
348
+ assert_np_equal(res, ref, 0.0001)
349
+
350
+ # test incompatible shapes
351
+ if block_shape[0] != block_shape[-1]:
352
+ with test.assertRaisesRegex(ValueError, "Incompatible block sizes"):
353
+ bsr_mm(z, y)
354
+
355
+ y.ncol = y.ncol * 2
356
+ with test.assertRaisesRegex(ValueError, "Incompatible number of rows/columns"):
357
+ bsr_mm(y, z)
227
358
 
228
359
  return test_bsr_mm
229
360
 
230
361
 
231
- def register(parent):
232
- devices = get_test_devices()
362
+ def make_test_bsr_mv(block_shape, scalar_type):
363
+ def test_bsr_mv(test, device):
364
+ rng = np.random.default_rng(123)
233
365
 
234
- class TestSparse(parent):
235
- pass
366
+ nrow = 2
367
+ ncol = 3
368
+ nnz = 6
236
369
 
237
- add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
238
- add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
239
- add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_diag, devices=devices)
370
+ alphas = [-1.0, 0.0, 1.0]
371
+ betas = [2.0, -1.0, 0.0]
372
+ A_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
373
+ A_cols = wp.array(rng.integers(0, high=ncol, size=nnz, dtype=int), dtype=int, device=device)
374
+ A_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
375
+ A_vals = A_vals.reshape((nnz, block_shape[0], block_shape[1]))
240
376
 
241
- add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
242
- add_function_test(
243
- TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices
244
- )
245
- add_function_test(
246
- TestSparse, "test_bsr_transpose_3_3", make_test_bsr_transpose((3, 3), wp.float64), devices=devices
247
- )
377
+ A = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
378
+ bsr_set_from_triplets(A, A_rows, A_cols, A_vals)
379
+
380
+ if block_shape[1] == 1:
381
+ x = wp.array(rng.random(size=ncol), dtype=scalar_type, device=device)
382
+ else:
383
+ x = wp.array(
384
+ rng.random(size=(ncol, block_shape[1])),
385
+ dtype=wp.vec(length=block_shape[1], dtype=scalar_type),
386
+ device=device,
387
+ )
388
+
389
+ if block_shape[0] == 1:
390
+ y = wp.array(rng.random(size=nrow), dtype=scalar_type, device=device)
391
+ else:
392
+ y = wp.array(
393
+ rng.random(size=(nrow, block_shape[0])),
394
+ dtype=wp.vec(length=block_shape[0], dtype=scalar_type),
395
+ device=device,
396
+ )
397
+
398
+ work_buffer = wp.empty_like(y)
399
+ for alpha, beta in zip(alphas, betas):
400
+ ref = alpha * _bsr_to_dense(A) @ x.numpy().flatten() + beta * y.numpy().flatten()
401
+ if beta == 0.0:
402
+ y = bsr_mv(A, x, alpha=alpha, beta=beta, work_buffer=work_buffer)
403
+ else:
404
+ bsr_mv(A, x, y, alpha, beta, work_buffer=work_buffer)
405
+
406
+ res = y.numpy().flatten()
407
+ assert_np_equal(res, ref, 0.0001)
408
+
409
+ # test aliasing
410
+ alpha, beta = alphas[0], betas[0]
411
+ AAt = bsr_mm(A, bsr_transposed(A))
412
+ ref = alpha * _bsr_to_dense(AAt) @ y.numpy().flatten() + beta * y.numpy().flatten()
413
+ bsr_mv(AAt, y, y, alpha, beta)
414
+ res = y.numpy().flatten()
415
+ assert_np_equal(res, ref, 0.0001)
416
+
417
+ A.ncol = A.ncol + 1
418
+ with test.assertRaisesRegex(ValueError, "Number of columns"):
419
+ bsr_mv(A, x, y)
420
+
421
+ A.ncol = A.ncol - 1
422
+ A.nrow = A.nrow - 1
423
+ with test.assertRaisesRegex(ValueError, "Number of rows"):
424
+ bsr_mv(A, x, y)
425
+
426
+ return test_bsr_mv
427
+
428
+
429
+ devices = get_test_devices()
430
+
431
+
432
+ class TestSparse(unittest.TestCase):
433
+ pass
434
+
435
+
436
+ add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
437
+ add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
438
+ add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_set_diag, devices=devices)
439
+ add_function_test(TestSparse, "test_bsr_copy_scale", test_bsr_copy_scale, devices=devices)
440
+
441
+ add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
442
+ add_function_test(TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices)
443
+ add_function_test(TestSparse, "test_bsr_transpose_3_3", make_test_bsr_transpose((3, 3), wp.float64), devices=devices)
248
444
 
249
- add_function_test(TestSparse, "test_csr_axpy", make_test_bsr_axpy((1, 1), wp.float32), devices=devices)
250
- add_function_test(TestSparse, "test_bsr_axpy_1_3", make_test_bsr_axpy((1, 3), wp.float32), devices=devices)
251
- add_function_test(TestSparse, "test_bsr_axpy_3_3", make_test_bsr_axpy((3, 3), wp.float64), devices=devices)
445
+ add_function_test(TestSparse, "test_csr_axpy", make_test_bsr_axpy((1, 1), wp.float32), devices=devices)
446
+ add_function_test(TestSparse, "test_bsr_axpy_1_3", make_test_bsr_axpy((1, 3), wp.float32), devices=devices)
447
+ add_function_test(TestSparse, "test_bsr_axpy_3_3", make_test_bsr_axpy((3, 3), wp.float64), devices=devices)
252
448
 
253
- add_function_test(TestSparse, "test_csr_mm", make_test_bsr_mm((1, 1), wp.float32), devices=devices)
254
- add_function_test(TestSparse, "test_bsr_mm_1_3", make_test_bsr_mm((1, 3), wp.float32), devices=devices)
255
- add_function_test(TestSparse, "test_bsr_mm_3_3", make_test_bsr_mm((3, 3), wp.float64), devices=devices)
449
+ add_function_test(TestSparse, "test_csr_mm", make_test_bsr_mm((1, 1), wp.float32), devices=devices)
450
+ add_function_test(TestSparse, "test_bsr_mm_1_3", make_test_bsr_mm((1, 3), wp.float32), devices=devices)
451
+ add_function_test(TestSparse, "test_bsr_mm_3_3", make_test_bsr_mm((3, 3), wp.float64), devices=devices)
256
452
 
257
- return TestSparse
453
+ add_function_test(TestSparse, "test_csr_mv", make_test_bsr_mv((1, 1), wp.float32), devices=devices)
454
+ add_function_test(TestSparse, "test_bsr_mv_1_3", make_test_bsr_mv((1, 3), wp.float32), devices=devices)
455
+ add_function_test(TestSparse, "test_bsr_mv_3_3", make_test_bsr_mv((3, 3), wp.float64), devices=devices)
258
456
 
259
457
 
260
458
  if __name__ == "__main__":
261
- c = register(unittest.TestCase)
459
+ wp.build.clear_kernel_cache()
262
460
  unittest.main(verbosity=2)