warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,143 @@
1
+ import unittest
2
+
3
+ import numpy as np
4
+
5
+ import warp as wp
6
+ from warp.tests.unittest_utils import *
7
+
8
+ wp.init()
9
+
10
+
11
+ def test_basic(test, device):
12
+ snippet = """
13
+ out[tid] = a * x[tid] + y[tid];
14
+ """
15
+ adj_snippet = """
16
+ adj_a = x[tid] * adj_out[tid];
17
+ adj_x[tid] = a * adj_out[tid];
18
+ adj_y[tid] = adj_out[tid];
19
+ """
20
+
21
+ @wp.func_native(snippet, adj_snippet)
22
+ def saxpy(
23
+ a: wp.float32,
24
+ x: wp.array(dtype=wp.float32),
25
+ y: wp.array(dtype=wp.float32),
26
+ out: wp.array(dtype=wp.float32),
27
+ tid: int,
28
+ ):
29
+ ...
30
+
31
+ @wp.kernel
32
+ def saxpy_cu(
33
+ a: wp.float32, x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32), out: wp.array(dtype=wp.float32)
34
+ ):
35
+ tid = wp.tid()
36
+ saxpy(a, x, y, out, tid)
37
+
38
+ @wp.kernel
39
+ def saxpy_py(
40
+ a: wp.float32, x: wp.array(dtype=wp.float32), y: wp.array(dtype=wp.float32), out: wp.array(dtype=wp.float32)
41
+ ):
42
+ tid = wp.tid()
43
+ out[tid] = a * x[tid] + y[tid]
44
+
45
+ N = 128
46
+
47
+ a1 = 2.0
48
+ x1 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device, requires_grad=True)
49
+ y1 = wp.zeros_like(x1)
50
+ out1 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device)
51
+ adj_out1 = wp.array(np.ones(N, dtype=np.float32), dtype=wp.float32, device=device)
52
+
53
+ a2 = 2.0
54
+ x2 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device, requires_grad=True)
55
+ y2 = wp.zeros_like(x2)
56
+ out2 = wp.array(np.arange(N, dtype=np.float32), dtype=wp.float32, device=device)
57
+ adj_out2 = wp.array(np.ones(N, dtype=np.float32), dtype=wp.float32, device=device)
58
+
59
+ tape = wp.Tape()
60
+
61
+ with tape:
62
+ wp.launch(kernel=saxpy_cu, dim=N, inputs=[a1, x1, y1], outputs=[out1], device=device)
63
+ wp.launch(kernel=saxpy_py, dim=N, inputs=[a2, x2, y2], outputs=[out2], device=device)
64
+
65
+ tape.backward(grads={out1: adj_out1, out2: adj_out2})
66
+
67
+ # test forward snippet
68
+ assert_np_equal(out1.numpy(), out2.numpy())
69
+
70
+ # test backward snippet
71
+ assert_np_equal(x1.grad.numpy(), a1 * np.ones(N, dtype=np.float32))
72
+ assert_np_equal(x1.grad.numpy(), x2.grad.numpy())
73
+
74
+ assert_np_equal(y1.grad.numpy(), np.ones(N, dtype=np.float32))
75
+ assert_np_equal(y1.grad.numpy(), y2.grad.numpy())
76
+
77
+
78
+ def test_shared_memory(test, device):
79
+ snippet = """
80
+ __shared__ int s[128];
81
+
82
+ s[tid] = d[tid];
83
+ __syncthreads();
84
+ d[tid] = s[N - tid - 1];
85
+ """
86
+
87
+ @wp.func_native(snippet)
88
+ def reverse(d: wp.array(dtype=int), N: int, tid: int):
89
+ return
90
+
91
+ @wp.kernel
92
+ def reverse_kernel(d: wp.array(dtype=int), N: int):
93
+ tid = wp.tid()
94
+ reverse(d, N, tid)
95
+
96
+ N = 128
97
+ x = wp.array(np.arange(N, dtype=int), dtype=int, device=device)
98
+ y = np.arange(127, -1, -1, dtype=int)
99
+
100
+ wp.launch(kernel=reverse_kernel, dim=N, inputs=[x, N], device=device)
101
+
102
+ assert_np_equal(x.numpy(), y)
103
+
104
+
105
+ def test_cpu_snippet(test, device):
106
+ snippet = """
107
+ int inc = 1;
108
+ out[tid] = x[tid] + inc;
109
+ """
110
+
111
+ @wp.func_native(snippet)
112
+ def increment_snippet(
113
+ x: wp.array(dtype=wp.int32),
114
+ out: wp.array(dtype=wp.int32),
115
+ tid: int,
116
+ ):
117
+ ...
118
+
119
+ @wp.kernel
120
+ def increment(x: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.int32)):
121
+ tid = wp.tid()
122
+ increment_snippet(x, out, tid)
123
+
124
+ N = 128
125
+ x = wp.array(np.arange(N, dtype=np.int32), dtype=wp.int32, device=device)
126
+ out = wp.zeros(N, dtype=wp.int32, device=device)
127
+
128
+ wp.launch(kernel=increment, dim=N, inputs=[x], outputs=[out], device=device)
129
+
130
+ assert_np_equal(out.numpy(), np.arange(1, N + 1, 1, dtype=np.int32))
131
+
132
+
133
+ class TestSnippets(unittest.TestCase):
134
+ pass
135
+
136
+
137
+ add_function_test(TestSnippets, "test_basic", test_basic, devices=get_unique_cuda_test_devices())
138
+ add_function_test(TestSnippets, "test_shared_memory", test_shared_memory, devices=get_unique_cuda_test_devices())
139
+ add_function_test(TestSnippets, "test_cpu_snippet", test_cpu_snippet, devices=["cpu"])
140
+
141
+
142
+ if __name__ == "__main__":
143
+ unittest.main(verbosity=2)
warp/tests/test_sparse.py CHANGED
@@ -1,8 +1,20 @@
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
1
10
  import numpy as np
11
+
2
12
  import warp as wp
13
+ from warp.sparse import bsr_zeros, bsr_set_from_triplets, bsr_get_diag, bsr_diag, bsr_identity, bsr_copy, bsr_scale
14
+ from warp.sparse import bsr_set_transpose, bsr_transposed
15
+ from warp.sparse import bsr_axpy, bsr_mm, bsr_axpy_work_arrays, bsr_mm_work_arrays, bsr_mv
16
+ from warp.tests.unittest_utils import *
3
17
 
4
- from warp.sparse import bsr_zeros, bsr_set_from_triplets, bsr_get_diag, bsr_diag, bsr_set_transpose, bsr_axpy, bsr_mm
5
- from warp.tests.test_base import *
6
18
 
7
19
  wp.init()
8
20
 
@@ -46,45 +58,62 @@ def _bsr_to_dense(bsr):
46
58
 
47
59
 
48
60
  def test_csr_from_triplets(test, device):
61
+ rng = np.random.default_rng(123)
62
+
49
63
  shape = (8, 6)
50
64
  n = 100
51
65
 
52
- rows = wp.array(np.random.randint(0, shape[0], n, dtype=int), dtype=int, device=device)
53
- cols = wp.array(np.random.randint(0, shape[1], n, dtype=int), dtype=int, device=device)
54
- vals = wp.array(np.random.rand(n), dtype=float, device=device)
66
+ rows = wp.array(rng.integers(0, high=shape[0], size=n, dtype=int), dtype=int, device=device)
67
+ cols = wp.array(rng.integers(0, high=shape[1], size=n, dtype=int), dtype=int, device=device)
68
+ vals = wp.array(rng.random(size=n), dtype=float, device=device)
55
69
 
56
70
  ref = _triplets_to_dense(shape, rows, cols, vals)
57
71
 
58
72
  csr = bsr_zeros(shape[0], shape[1], float, device=device)
59
73
  bsr_set_from_triplets(csr, rows, cols, vals)
74
+ test.assertEqual(csr.block_size, 1)
60
75
 
61
76
  res = _bsr_to_dense(csr)
62
77
 
63
- assert_np_equal(ref, res, 0.0001)
78
+ assert_np_equal(res, ref, 0.0001)
64
79
 
65
80
 
66
81
  def test_bsr_from_triplets(test, device):
82
+ rng = np.random.default_rng(123)
83
+
67
84
  block_shape = (3, 2)
68
85
  nrow = 4
69
86
  ncol = 9
70
87
  shape = (block_shape[0] * nrow, block_shape[1] * ncol)
71
88
  n = 50
72
89
 
73
- rows = wp.array(np.random.randint(0, nrow, n, dtype=int), dtype=int, device=device)
74
- cols = wp.array(np.random.randint(0, ncol, n, dtype=int), dtype=int, device=device)
75
- vals = wp.array(np.random.rand(n, block_shape[0], block_shape[1]), dtype=float, device=device)
90
+ rows = wp.array(rng.integers(0, high=nrow, size=n, dtype=int), dtype=int, device=device)
91
+ cols = wp.array(rng.integers(0, high=ncol, size=n, dtype=int), dtype=int, device=device)
92
+ vals = wp.array(rng.random(size=(n, block_shape[0], block_shape[1])), dtype=float, device=device)
76
93
 
77
94
  ref = _triplets_to_dense(shape, rows, cols, vals)
78
95
 
79
96
  bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
80
97
  bsr_set_from_triplets(bsr, rows, cols, vals)
98
+ test.assertEqual(bsr.block_size, block_shape[0] * block_shape[1])
81
99
 
82
100
  res = _bsr_to_dense(bsr)
83
101
 
84
- assert_np_equal(ref, res, 0.0001)
102
+ assert_np_equal(res, ref, 0.0001)
103
+
104
+ # test zero-length inputs
105
+ bsr_set_from_triplets(
106
+ bsr,
107
+ wp.array([], dtype=int, device=device),
108
+ wp.array([], dtype=int, device=device),
109
+ wp.array([], shape=(0, block_shape[0], block_shape[1]), dtype=float, device=device),
110
+ )
111
+ test.assertEqual(bsr.nnz, 0)
112
+
85
113
 
114
+ def test_bsr_get_set_diag(test, device):
115
+ rng = np.random.default_rng(123)
86
116
 
87
- def test_bsr_get_diag(test, device):
88
117
  block_shape = (3, 3)
89
118
  nrow = 4
90
119
  ncol = 4
@@ -92,7 +121,7 @@ def test_bsr_get_diag(test, device):
92
121
 
93
122
  rows = wp.array([0, 1, 2, 3, 2, 1], dtype=int, device=device)
94
123
  cols = wp.array([1, 1, 1, 3, 2, 2], dtype=int, device=device)
95
- vals_np = np.random.rand(nnz, block_shape[0], block_shape[1])
124
+ vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
96
125
  vals = wp.array(vals_np, dtype=float, device=device)
97
126
 
98
127
  bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
@@ -106,14 +135,46 @@ def test_bsr_get_diag(test, device):
106
135
  assert_np_equal(diag_np[2], vals_np[4], tol=0.00001)
107
136
  assert_np_equal(diag_np[3], vals_np[3], tol=0.00001)
108
137
 
109
- # Test round-trip
138
+ # Test set_diag/get_diag round-trips with various block types
139
+
140
+ # Array of blocks
110
141
  diag_bsr = bsr_diag(diag)
111
- diag = bsr_get_diag(diag_bsr)
142
+ bsr_get_diag(diag_bsr, out=diag)
112
143
  assert_np_equal(diag_np, diag.numpy())
113
144
 
145
+ diag_scalar_np = rng.random(size=nrow)
146
+ diag_scalar = wp.array(diag_scalar_np, device=device)
147
+ diag_bsr = bsr_diag(diag_scalar)
148
+ diag = bsr_get_diag(diag_bsr)
149
+ assert_np_equal(diag_scalar_np, diag.numpy(), tol=0.000001)
150
+
151
+ # Uniform block diagonal
152
+
153
+ with test.assertRaisesRegex(ValueError, "BsrMatrix block type must be either warp matrix or scalar"):
154
+ # 1d block type -- invalid
155
+ diag_bsr = bsr_diag(diag=vals_np[0, 0], rows_of_blocks=nrow, cols_of_blocks=nrow + 1)
156
+
157
+ diag_bsr = bsr_diag(diag=vals_np[0], rows_of_blocks=nrow, cols_of_blocks=nrow + 1)
158
+ assert diag_bsr.values.shape[0] == nrow
159
+ assert_np_equal(diag_bsr.values.numpy(), np.broadcast_to(vals_np[0], shape=(nrow, *block_shape)), tol=0.000001)
160
+
161
+ diag_bsr = bsr_diag(diag=float(diag_scalar_np[0]), rows_of_blocks=nrow, cols_of_blocks=nrow + 1)
162
+ assert diag_bsr.values.shape[0] == nrow
163
+ assert_np_equal(diag_bsr.values.numpy(), np.full(nrow, diag_scalar_np[0]), tol=0.000001)
164
+
165
+ # Identity matrix
166
+ diag_bsr = bsr_identity(nrow, block_type=wp.mat44, device=device)
167
+ assert diag_bsr.values.shape[0] == nrow
168
+ assert_np_equal(diag_bsr.values.numpy(), np.broadcast_to(np.eye(4), shape=(nrow, 4, 4)), tol=0.000001)
169
+
170
+ diag_csr = bsr_identity(nrow, block_type=wp.float64, device=device)
171
+ assert np.all(diag_csr.values.numpy() == np.ones(nrow, dtype=float))
172
+
114
173
 
115
174
  def make_test_bsr_transpose(block_shape, scalar_type):
116
175
  def test_bsr_transpose(test, device):
176
+ rng = np.random.default_rng(123)
177
+
117
178
  nrow = 4
118
179
  ncol = 5
119
180
  nnz = 6
@@ -121,7 +182,7 @@ def make_test_bsr_transpose(block_shape, scalar_type):
121
182
  rows = wp.array([0, 1, 2, 3, 2, 1], dtype=int, device=device)
122
183
  cols = wp.array([1, 4, 1, 3, 0, 2], dtype=int, device=device)
123
184
 
124
- vals_np = np.random.rand(nnz, block_shape[0], block_shape[1])
185
+ vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
125
186
  vals = wp.array(vals_np, dtype=scalar_type, device=device).reshape((nnz, block_shape[0], block_shape[1]))
126
187
 
127
188
  bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
@@ -134,49 +195,92 @@ def make_test_bsr_transpose(block_shape, scalar_type):
134
195
  bsr_set_transpose(dest=bsr_transposed, src=bsr)
135
196
 
136
197
  res = _bsr_to_dense(bsr_transposed)
198
+ assert_np_equal(res, ref, 0.0001)
137
199
 
138
- assert_np_equal(ref, res, 0.0001)
200
+ if block_shape[0] != block_shape[-1]:
201
+ # test incompatible block shape
202
+ with test.assertRaisesRegex(ValueError, "Destination block shape must be"):
203
+ bsr_set_transpose(dest=bsr, src=bsr)
139
204
 
140
205
  return test_bsr_transpose
141
206
 
142
207
 
208
+ def test_bsr_copy_scale(test, device):
209
+ nrow = 6
210
+ bsize = 2
211
+
212
+ diag_bsr = bsr_diag(diag=np.eye(bsize, dtype=float) * 2.0, rows_of_blocks=nrow)
213
+ diag_copy = bsr_copy(diag_bsr, scalar_type=wp.float64)
214
+
215
+ test.assertTrue(wp.types.types_equal(diag_copy.values.dtype, wp.mat(shape=(bsize, bsize), dtype=wp.float64)))
216
+ bsr_scale(x=diag_copy, alpha=0.5)
217
+
218
+ res = _bsr_to_dense(diag_copy)
219
+ ref = np.eye(nrow * bsize)
220
+ assert_np_equal(res, ref, 0.0001)
221
+
222
+ bsr_scale(x=diag_copy, alpha=0.0)
223
+ test.assertEqual(diag_copy.nrow, nrow)
224
+ test.assertEqual(diag_copy.ncol, nrow)
225
+ test.assertEqual(diag_copy.nnz, 0)
226
+
227
+
143
228
  def make_test_bsr_axpy(block_shape, scalar_type):
144
229
  def test_bsr_axpy(test, device):
230
+ rng = np.random.default_rng(123)
231
+
145
232
  nrow = 2
146
233
  ncol = 3
147
234
  nnz = 6
148
235
 
149
- alpha = -1.0
150
- beta = 2.0
236
+ alphas = [-1.0, 0.0, 1.0]
237
+ betas = [2.0, -1.0, 0.0]
151
238
 
152
- x_rows = wp.array(np.random.randint(0, nrow, nnz, dtype=int), dtype=int, device=device)
153
- x_cols = wp.array(np.random.randint(0, ncol, nnz, dtype=int), dtype=int, device=device)
154
- x_vals = wp.array(np.random.rand(nnz, block_shape[0], block_shape[1]), dtype=scalar_type, device=device)
239
+ x_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
240
+ x_cols = wp.array(rng.integers(0, high=ncol, size=nnz, dtype=int), dtype=int, device=device)
241
+ x_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
155
242
  x_vals = x_vals.reshape((nnz, block_shape[0], block_shape[1]))
156
243
 
157
244
  x = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
158
245
  bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
159
246
 
160
- y_rows = wp.array(np.random.randint(0, nrow, nnz, dtype=int), dtype=int, device=device)
161
- y_cols = wp.array(np.random.randint(0, ncol, nnz, dtype=int), dtype=int, device=device)
162
- y_vals = wp.array(np.random.rand(nnz, block_shape[0], block_shape[1]), dtype=scalar_type, device=device)
247
+ y_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
248
+ y_cols = wp.array(rng.integers(0, high=ncol, size=nnz, dtype=int), dtype=int, device=device)
249
+ y_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
163
250
  y_vals = y_vals.reshape((nnz, block_shape[0], block_shape[1]))
164
251
 
165
252
  y = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
166
253
  bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
167
254
 
168
- ref = alpha * _bsr_to_dense(x) + beta * _bsr_to_dense(y)
255
+ work_arrays = bsr_axpy_work_arrays()
256
+ for alpha, beta in zip(alphas, betas):
257
+ ref = alpha * _bsr_to_dense(x) + beta * _bsr_to_dense(y)
258
+ if beta == 0.0:
259
+ y = bsr_axpy(x, alpha=alpha, beta=beta, work_arrays=work_arrays)
260
+ else:
261
+ bsr_axpy(x, y, alpha, beta, work_arrays=work_arrays)
169
262
 
170
- bsr_axpy(x, y, alpha, beta)
263
+ res = _bsr_to_dense(y)
264
+ assert_np_equal(res, ref, 0.0001)
171
265
 
266
+ # test aliasing
267
+ ref = 3.0 * _bsr_to_dense(y)
268
+ bsr_axpy(y, y, alpha=1.0, beta=2.0)
172
269
  res = _bsr_to_dense(y)
173
- assert_np_equal(ref, res, 0.0001)
270
+ assert_np_equal(res, ref, 0.0001)
271
+
272
+ # test incompatible shapes
273
+ y.ncol = y.ncol + 1
274
+ with test.assertRaisesRegex(ValueError, "Matrices must have the same number of rows and columns"):
275
+ bsr_axpy(x, y)
174
276
 
175
277
  return test_bsr_axpy
176
278
 
177
279
 
178
280
  def make_test_bsr_mm(block_shape, scalar_type):
179
281
  def test_bsr_mm(test, device):
282
+ rng = np.random.default_rng(123)
283
+
180
284
  x_nrow = 3
181
285
  x_ncol = 2
182
286
  x_block_shape = block_shape
@@ -191,72 +295,166 @@ def make_test_bsr_mm(block_shape, scalar_type):
191
295
 
192
296
  nnz = 6
193
297
 
194
- alpha = -1.0
195
- beta = 2.0
298
+ alphas = [-1.0, 0.0, 1.0]
299
+ betas = [2.0, -1.0, 0.0]
196
300
 
197
- x_rows = wp.array(np.random.randint(0, x_nrow, nnz, dtype=int), dtype=int, device=device)
198
- x_cols = wp.array(np.random.randint(0, x_ncol, nnz, dtype=int), dtype=int, device=device)
199
- x_vals = wp.array(np.random.rand(nnz, x_block_shape[0], x_block_shape[1]), dtype=scalar_type, device=device)
301
+ x_rows = wp.array(rng.integers(0, high=x_nrow, size=nnz, dtype=int), dtype=int, device=device)
302
+ x_cols = wp.array(rng.integers(0, high=x_ncol, size=nnz, dtype=int), dtype=int, device=device)
303
+ x_vals = wp.array(rng.random(size=(nnz, x_block_shape[0], x_block_shape[1])), dtype=scalar_type, device=device)
200
304
  x_vals = x_vals.reshape((nnz, x_block_shape[0], x_block_shape[1]))
201
305
 
202
306
  x = bsr_zeros(x_nrow, x_ncol, wp.types.matrix(shape=x_block_shape, dtype=scalar_type), device=device)
203
307
  bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
204
308
 
205
- y_rows = wp.array(np.random.randint(0, y_nrow, nnz, dtype=int), dtype=int, device=device)
206
- y_cols = wp.array(np.random.randint(0, y_ncol, nnz, dtype=int), dtype=int, device=device)
207
- y_vals = wp.array(np.random.rand(nnz, y_block_shape[0], y_block_shape[1]), dtype=scalar_type, device=device)
309
+ y_rows = wp.array(rng.integers(0, high=y_nrow, size=nnz, dtype=int), dtype=int, device=device)
310
+ y_cols = wp.array(rng.integers(0, high=y_ncol, size=nnz, dtype=int), dtype=int, device=device)
311
+ y_vals = wp.array(rng.random(size=(nnz, y_block_shape[0], y_block_shape[1])), dtype=scalar_type, device=device)
208
312
  y_vals = y_vals.reshape((nnz, y_block_shape[0], y_block_shape[1]))
209
313
 
210
314
  y = bsr_zeros(y_nrow, y_ncol, wp.types.matrix(shape=y_block_shape, dtype=scalar_type), device=device)
211
315
  bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
212
316
 
213
- z_rows = wp.array(np.random.randint(0, z_nrow, nnz, dtype=int), dtype=int, device=device)
214
- z_cols = wp.array(np.random.randint(0, z_ncol, nnz, dtype=int), dtype=int, device=device)
215
- z_vals = wp.array(np.random.rand(nnz, z_block_shape[0], z_block_shape[1]), dtype=scalar_type, device=device)
317
+ z_rows = wp.array(rng.integers(0, high=z_nrow, size=nnz, dtype=int), dtype=int, device=device)
318
+ z_cols = wp.array(rng.integers(0, high=z_ncol, size=nnz, dtype=int), dtype=int, device=device)
319
+ z_vals = wp.array(rng.random(size=(nnz, z_block_shape[0], z_block_shape[1])), dtype=scalar_type, device=device)
216
320
  z_vals = z_vals.reshape((nnz, z_block_shape[0], z_block_shape[1]))
217
321
 
218
322
  z = bsr_zeros(z_nrow, z_ncol, wp.types.matrix(shape=z_block_shape, dtype=scalar_type), device=device)
219
323
  bsr_set_from_triplets(z, z_rows, z_cols, z_vals)
220
324
 
221
- ref = alpha * (_bsr_to_dense(x) @ _bsr_to_dense(y)) + beta * _bsr_to_dense(z)
325
+ work_arrays = bsr_mm_work_arrays()
326
+ for alpha, beta in zip(alphas, betas):
327
+ ref = alpha * (_bsr_to_dense(x) @ _bsr_to_dense(y)) + beta * _bsr_to_dense(z)
328
+
329
+ bsr_mm(x, y, z, alpha, beta, work_arrays=work_arrays)
330
+
331
+ res = _bsr_to_dense(z)
332
+ assert_np_equal(res, ref, 0.0001)
333
+
334
+ # test aliasing of matrix arguments
335
+ # x = alpha * z * x + beta * x
336
+ alpha, beta = alphas[0], betas[0]
337
+ ref = alpha * (_bsr_to_dense(z) @ _bsr_to_dense(x)) + beta * _bsr_to_dense(x)
338
+ bsr_mm(z, x, x, alpha, beta)
222
339
 
223
- bsr_mm(x, y, z, alpha, beta)
340
+ res = _bsr_to_dense(x)
341
+ assert_np_equal(res, ref, 0.0001)
342
+
343
+ # z = alpha * z * z + beta * z
344
+ ref = alpha * (_bsr_to_dense(z) @ _bsr_to_dense(z)) + beta * _bsr_to_dense(z)
345
+ bsr_mm(z, z, z, alpha, beta)
224
346
 
225
347
  res = _bsr_to_dense(z)
226
- assert_np_equal(ref, res, 0.0001)
348
+ assert_np_equal(res, ref, 0.0001)
349
+
350
+ # test incompatible shapes
351
+ if block_shape[0] != block_shape[-1]:
352
+ with test.assertRaisesRegex(ValueError, "Incompatible block sizes"):
353
+ bsr_mm(z, y)
354
+
355
+ y.ncol = y.ncol * 2
356
+ with test.assertRaisesRegex(ValueError, "Incompatible number of rows/columns"):
357
+ bsr_mm(y, z)
227
358
 
228
359
  return test_bsr_mm
229
360
 
230
361
 
231
- def register(parent):
232
- devices = get_test_devices()
362
+ def make_test_bsr_mv(block_shape, scalar_type):
363
+ def test_bsr_mv(test, device):
364
+ rng = np.random.default_rng(123)
233
365
 
234
- class TestSparse(parent):
235
- pass
366
+ nrow = 2
367
+ ncol = 3
368
+ nnz = 6
236
369
 
237
- add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
238
- add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
239
- add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_diag, devices=devices)
370
+ alphas = [-1.0, 0.0, 1.0]
371
+ betas = [2.0, -1.0, 0.0]
372
+ A_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
373
+ A_cols = wp.array(rng.integers(0, high=ncol, size=nnz, dtype=int), dtype=int, device=device)
374
+ A_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
375
+ A_vals = A_vals.reshape((nnz, block_shape[0], block_shape[1]))
240
376
 
241
- add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
242
- add_function_test(
243
- TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices
244
- )
245
- add_function_test(
246
- TestSparse, "test_bsr_transpose_3_3", make_test_bsr_transpose((3, 3), wp.float64), devices=devices
247
- )
377
+ A = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
378
+ bsr_set_from_triplets(A, A_rows, A_cols, A_vals)
379
+
380
+ if block_shape[1] == 1:
381
+ x = wp.array(rng.random(size=ncol), dtype=scalar_type, device=device)
382
+ else:
383
+ x = wp.array(
384
+ rng.random(size=(ncol, block_shape[1])),
385
+ dtype=wp.vec(length=block_shape[1], dtype=scalar_type),
386
+ device=device,
387
+ )
388
+
389
+ if block_shape[0] == 1:
390
+ y = wp.array(rng.random(size=nrow), dtype=scalar_type, device=device)
391
+ else:
392
+ y = wp.array(
393
+ rng.random(size=(nrow, block_shape[0])),
394
+ dtype=wp.vec(length=block_shape[0], dtype=scalar_type),
395
+ device=device,
396
+ )
397
+
398
+ work_buffer = wp.empty_like(y)
399
+ for alpha, beta in zip(alphas, betas):
400
+ ref = alpha * _bsr_to_dense(A) @ x.numpy().flatten() + beta * y.numpy().flatten()
401
+ if beta == 0.0:
402
+ y = bsr_mv(A, x, alpha=alpha, beta=beta, work_buffer=work_buffer)
403
+ else:
404
+ bsr_mv(A, x, y, alpha, beta, work_buffer=work_buffer)
405
+
406
+ res = y.numpy().flatten()
407
+ assert_np_equal(res, ref, 0.0001)
408
+
409
+ # test aliasing
410
+ alpha, beta = alphas[0], betas[0]
411
+ AAt = bsr_mm(A, bsr_transposed(A))
412
+ ref = alpha * _bsr_to_dense(AAt) @ y.numpy().flatten() + beta * y.numpy().flatten()
413
+ bsr_mv(AAt, y, y, alpha, beta)
414
+ res = y.numpy().flatten()
415
+ assert_np_equal(res, ref, 0.0001)
416
+
417
+ A.ncol = A.ncol + 1
418
+ with test.assertRaisesRegex(ValueError, "Number of columns"):
419
+ bsr_mv(A, x, y)
420
+
421
+ A.ncol = A.ncol - 1
422
+ A.nrow = A.nrow - 1
423
+ with test.assertRaisesRegex(ValueError, "Number of rows"):
424
+ bsr_mv(A, x, y)
425
+
426
+ return test_bsr_mv
427
+
428
+
429
+ devices = get_test_devices()
430
+
431
+
432
+ class TestSparse(unittest.TestCase):
433
+ pass
434
+
435
+
436
+ add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
437
+ add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
438
+ add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_set_diag, devices=devices)
439
+ add_function_test(TestSparse, "test_bsr_copy_scale", test_bsr_copy_scale, devices=devices)
440
+
441
+ add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
442
+ add_function_test(TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices)
443
+ add_function_test(TestSparse, "test_bsr_transpose_3_3", make_test_bsr_transpose((3, 3), wp.float64), devices=devices)
248
444
 
249
- add_function_test(TestSparse, "test_csr_axpy", make_test_bsr_axpy((1, 1), wp.float32), devices=devices)
250
- add_function_test(TestSparse, "test_bsr_axpy_1_3", make_test_bsr_axpy((1, 3), wp.float32), devices=devices)
251
- add_function_test(TestSparse, "test_bsr_axpy_3_3", make_test_bsr_axpy((3, 3), wp.float64), devices=devices)
445
+ add_function_test(TestSparse, "test_csr_axpy", make_test_bsr_axpy((1, 1), wp.float32), devices=devices)
446
+ add_function_test(TestSparse, "test_bsr_axpy_1_3", make_test_bsr_axpy((1, 3), wp.float32), devices=devices)
447
+ add_function_test(TestSparse, "test_bsr_axpy_3_3", make_test_bsr_axpy((3, 3), wp.float64), devices=devices)
252
448
 
253
- add_function_test(TestSparse, "test_csr_mm", make_test_bsr_mm((1, 1), wp.float32), devices=devices)
254
- add_function_test(TestSparse, "test_bsr_mm_1_3", make_test_bsr_mm((1, 3), wp.float32), devices=devices)
255
- add_function_test(TestSparse, "test_bsr_mm_3_3", make_test_bsr_mm((3, 3), wp.float64), devices=devices)
449
+ add_function_test(TestSparse, "test_csr_mm", make_test_bsr_mm((1, 1), wp.float32), devices=devices)
450
+ add_function_test(TestSparse, "test_bsr_mm_1_3", make_test_bsr_mm((1, 3), wp.float32), devices=devices)
451
+ add_function_test(TestSparse, "test_bsr_mm_3_3", make_test_bsr_mm((3, 3), wp.float64), devices=devices)
256
452
 
257
- return TestSparse
453
+ add_function_test(TestSparse, "test_csr_mv", make_test_bsr_mv((1, 1), wp.float32), devices=devices)
454
+ add_function_test(TestSparse, "test_bsr_mv_1_3", make_test_bsr_mv((1, 3), wp.float32), devices=devices)
455
+ add_function_test(TestSparse, "test_bsr_mv_3_3", make_test_bsr_mv((3, 3), wp.float64), devices=devices)
258
456
 
259
457
 
260
458
  if __name__ == "__main__":
261
- c = register(unittest.TestCase)
459
+ wp.build.clear_kernel_cache()
262
460
  unittest.main(verbosity=2)