warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2889 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
+ wp.init()
16
+
17
# NumPy scalar types the tests in this file are parameterized over.
# np.byte / np.ubyte are listed in addition to the fixed-width names, exactly
# as in the original lists (they are NumPy's C char scalar types).
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]

np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

# Signed entries first, then unsigned.
np_int_types = [*np_signed_int_types, *np_unsigned_int_types]

np_float_types = [np.float16, np.float32, np.float64]

# Every scalar type: integers first, then floats.
np_scalar_types = [*np_int_types, *np_float_types]
38
+
39
+
40
def randvals(rng, shape, dtype):
    """Draw a random array of the requested shape and dtype from ``rng``.

    Float dtypes (those in ``np_float_types``) are sampled with
    ``rng.standard_normal`` and cast to ``dtype``.  Every other dtype is
    sampled with ``rng.integers`` starting at 1; the 8-bit integer types use
    ``high=3`` and all remaining types use ``high=5``.
    """
    if dtype in np_float_types:
        samples = rng.standard_normal(size=shape)
        return samples.astype(dtype)

    # 8-bit types get a smaller range of values than the wider integer types.
    eight_bit_types = [np.int8, np.uint8, np.byte, np.ubyte]
    upper = 3 if dtype in eight_bit_types else 5
    return rng.integers(1, high=upper, size=shape, dtype=dtype)
46
+
47
+
48
# Kernels built so far, keyed by "<function name>_<suffix>".
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` for ``func``, building it on first request.

    The cache key (also passed to ``wp.Kernel`` as ``key``) is the function's
    ``__name__`` joined to ``suffix`` with an underscore, so the same function
    can be registered once per suffix.
    """
    key = "_".join((func.__name__, suffix))
    try:
        return kernel_cache[key]
    except KeyError:
        built = wp.Kernel(func=func, key=key)
        kernel_cache[key] = built
        return built
56
+
57
+
58
def get_select_kernel(dtype):
    """Return the cached kernel that copies ``input[index]`` into ``out[0]``.

    Args:
        dtype: Element type of the ``input`` and ``out`` arrays.  Its
            ``__name__`` is used as the kernel-cache suffix, so one kernel is
            built per element type.

    Returns:
        The kernel object produced by ``getkernel``.
    """

    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index]

    # The original had a ``wp.launch(kernel, ...)`` statement after this
    # return; it could never execute and referenced an undefined name, so it
    # has been dropped.
    return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
69
+
70
+
71
def test_arrays(test, device, dtype):
    """Check that matrix-typed Warp arrays round-trip NumPy data unchanged.

    For each matrix shape, ten random matrices are generated with
    ``randvals``, uploaded as a ``wp.array`` of the matching matrix type on
    ``device``, and read back with ``.numpy()`` for comparison.

    Args:
        test: Test-case instance (unused here; kept for the common signature).
        device: Device on which the Warp arrays are created.
        dtype: NumPy scalar type selecting the matrix component type.
    """
    rng = np.random.default_rng(123)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Shapes are visited in the same order as the original draws (2x2, 3x3,
    # 4x4, 5x5, 3x2) so the random values per shape are unchanged.  The
    # original additionally repeated the 2x2/3x3/4x4 checks verbatim; that
    # redundant second pass has been removed.
    for shape in ((2, 2), (3, 3), (4, 4), (5, 5), (3, 2)):
        mat_type = wp.types.matrix(shape=shape, dtype=wptype)
        values_np = randvals(rng, [10, *shape], dtype)
        values = wp.array(values_np, dtype=mat_type, requires_grad=True, device=device)
        assert_np_equal(values.numpy(), values_np, tol=1.0e-6)
111
+
112
+
113
def test_components(test, device, dtype):
    """Exercise Python-side ``__getitem__`` / ``__setitem__`` on a 2x3 matrix.

    Accessing matrix components from Python is especially important for
    float16, which requires special handling internally.

    Args:
        test: Test-case instance providing ``assertEqual``.
        device: Unused here; kept for the common test signature.
        dtype: NumPy scalar type selecting the matrix component type.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)

    m = mat23(1, 2, 3, 4, 5, 6)

    # (row, column) pairs in row-major order: (0, 0), (0, 1), ..., (1, 2)
    cells = [(r, c) for r in range(2) for c in range(3)]

    # __getitem__ for row vectors: fetch both rows first, then inspect them
    row0 = m[0]
    row1 = m[1]
    for col in range(3):
        test.assertEqual(row0[col], 1 + col)
    for col in range(3):
        test.assertEqual(row1[col], 4 + col)

    # __getitem__ for individual components
    for expected, (r, c) in enumerate(cells, start=1):
        test.assertEqual(m[r, c], expected)

    # __setitem__ for row vectors
    m[0] = [7, 8, 9]
    m[1] = [10, 11, 12]
    for expected, (r, c) in enumerate(cells, start=7):
        test.assertEqual(m[r, c], expected)

    # __setitem__ for individual components: write every cell, then verify
    for value, (r, c) in enumerate(cells, start=13):
        m[r, c] = value
    for expected, (r, c) in enumerate(cells, start=13):
        test.assertEqual(m[r, c], expected)
163
+
164
+
165
def test_constants(test, device, dtype, register_kernels=False):
    """Check that matrix-valued ``wp.constant`` objects are usable in a kernel.

    Constants are created for 2x2, 3x3, 4x4, 5x5 and 3x2 matrices and compared
    inside a kernel (via ``wp.expect_eq``) against matrices constructed from
    the same scalar.

    Args:
        test: Test-case instance (unused here; kept for the common signature).
        device: Device on which the checking kernel is launched.
        dtype: NumPy scalar type selecting the matrix component type.
        register_kernels: When true, only build/cache the kernel and return
            without launching anything.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)

    cm22 = wp.constant(mat22(22))
    cm33 = wp.constant(mat33(33))
    cm44 = wp.constant(mat44(44))
    cm55 = wp.constant(mat55(55))
    cm32 = wp.constant(mat32(32))

    def check_matrix_constants():
        wp.expect_eq(cm22, mat22(wptype(22)))
        wp.expect_eq(cm33, mat33(wptype(33)))
        wp.expect_eq(cm44, mat44(wptype(44)))
        wp.expect_eq(cm55, mat55(wptype(55)))
        wp.expect_eq(cm32, mat32(wptype(32)))

    kernel = getkernel(check_matrix_constants, suffix=dtype.__name__)

    if register_kernels:
        return

    # The original built the kernel but never ran it, so none of the
    # expect_eq checks above were ever executed.
    wp.launch(kernel, dim=1, inputs=[], device=device)
190
+
191
+
192
def test_constructors(test, device, dtype, register_kernels=False):
    """Check scalar, per-component and per-vector matrix constructors.

    Three kernels build 2x2, 3x3, 4x4 and 5x5 matrices (from one scalar, from
    individual components, and from vectors), scale them by 2 and write all
    54 components to a flat output array.  The outputs are compared with the
    NumPy inputs, and for float dtypes each output component is
    back-propagated through a ``wp.Tape`` to check the input gradients.

    Args:
        test: Test-case instance providing ``assertEqual``.
        device: Device on which arrays are allocated and kernels launched.
        dtype: NumPy scalar type selecting the component type.
        register_kernels: When true, only build/cache the kernels and return.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per float type; integer types compare exactly
    tolerances = {np.float16: 1.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_scalar_mat_constructor(
        input: wp.array(dtype=wptype),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        m2result = wptype(2) * mat22(input[0])
        m3result = wptype(2) * mat33(input[0])
        m4result = wptype(2) * mat44(input[0])
        m5result = wptype(2) * mat55(input[0])

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = m2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = m3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = m4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = m5result[i, j]
                idx = idx + 1

    def check_component_mat_constructor(
        input: wp.array(dtype=wptype),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        m2result = wptype(2) * mat22(input[0], input[1], input[2], input[3])
        m3result = wptype(2) * mat33(
            input[4], input[5], input[6],
            input[7], input[8], input[9],
            input[10], input[11], input[12],
        )
        m4result = wptype(2) * mat44(
            input[13], input[14], input[15], input[16],
            input[17], input[18], input[19], input[20],
            input[21], input[22], input[23], input[24],
            input[25], input[26], input[27], input[28],
        )
        m5result = wptype(2) * mat55(
            input[29], input[30], input[31], input[32], input[33],
            input[34], input[35], input[36], input[37], input[38],
            input[39], input[40], input[41], input[42], input[43],
            input[44], input[45], input[46], input[47], input[48],
            input[49], input[50], input[51], input[52], input[53],
        )

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = m2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = m3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = m4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = m5result[i, j]
                idx = idx + 1

    def check_vector_mat_constructor(
        input: wp.array(dtype=wptype),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        m2result = wptype(2) * mat22(vec2(input[0], input[2]), vec2(input[1], input[3]))
        m3result = wptype(2) * mat33(
            vec3(input[4], input[7], input[10]),
            vec3(input[5], input[8], input[11]),
            vec3(input[6], input[9], input[12]),
        )
        m4result = wptype(2) * mat44(
            vec4(input[13], input[17], input[21], input[25]),
            vec4(input[14], input[18], input[22], input[26]),
            vec4(input[15], input[19], input[23], input[27]),
            vec4(input[16], input[20], input[24], input[28]),
        )
        m5result = wptype(2) * mat55(
            vec5(input[29], input[34], input[39], input[44], input[49]),
            vec5(input[30], input[35], input[40], input[45], input[50]),
            vec5(input[31], input[36], input[41], input[46], input[51]),
            vec5(input[32], input[37], input[42], input[47], input[52]),
            vec5(input[33], input[38], input[43], input[48], input[53]),
        )

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = m2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = m3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = m4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = m5result[i, j]
                idx = idx + 1

    suffix = dtype.__name__
    kernel = getkernel(check_scalar_mat_constructor, suffix=suffix)
    compkernel = getkernel(check_component_mat_constructor, suffix=suffix)
    veckernel = getkernel(check_vector_mat_constructor, suffix=suffix)

    if register_kernels:
        return

    # total number of components across the 2x2, 3x3, 4x4 and 5x5 matrices
    n_components = 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5
    check_grads = dtype in np_float_types

    scalar_in = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    val = scalar_in.numpy()[0]
    outcomponents = wp.zeros(n_components, dtype=wptype, requires_grad=True, device=device)
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    def record_backward(kern, source, component):
        # Record kern plus the component selection on a fresh tape, run the
        # backward pass from `out`, and hand the tape back so the caller can
        # inspect the gradients before zeroing it.
        tape = wp.Tape()
        with tape:
            wp.launch(kern, dim=1, inputs=[source], outputs=[outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device)
        tape.backward(loss=out)
        return tape

    # --- scalar constructor: every component equals 2 * val ---
    wp.launch(kernel, dim=1, inputs=[scalar_in], outputs=[outcomponents], device=device)

    # [start, stop) ranges of the 2x2, 3x3, 4x4 and 5x5 blocks in the output
    for start, stop in ((0, 4), (4, 13), (13, 29), (29, 54)):
        assert_np_equal(outcomponents.numpy()[start:stop], 2 * val * np.ones(stop - start), tol=tol)

    if check_grads:
        for idx in range(len(outcomponents)):
            tape = record_backward(kernel, scalar_in, idx)
            test.assertEqual(tape.gradients[scalar_in].numpy()[0], 2)
            tape.zero()

    # --- component and vector constructors share one input and the same checks ---
    full_in = wp.array(randvals(rng, [n_components], dtype), requires_grad=True, device=device)

    for kern in (compkernel, veckernel):
        wp.launch(kern, dim=1, inputs=[full_in], outputs=[outcomponents], device=device)
        assert_np_equal(2 * full_in.numpy(), outcomponents.numpy(), tol=10 * tol)

        if check_grads:
            for idx in range(len(outcomponents)):
                tape = record_backward(kern, full_in, idx)
                expectedgrads = np.zeros(len(full_in))
                expectedgrads[idx] = 2
                assert_np_equal(tape.gradients[full_in].numpy(), expectedgrads)
                tape.zero()
434
+
435
+
436
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Check ``wp.matrix(..., shape=...)`` construction inside kernels.

    Two kernels are exercised: one fills each matrix from a single scalar, the
    other builds each matrix from explicit components.  Both write twice the
    matrix entries into a flat output array, which is compared against NumPy
    and, for floating-point dtypes, differentiated one output entry at a time.

    Args:
        test: Test-case object supplied by the harness (not used here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only create the kernels and return.
    """
    rng = np.random.default_rng(123)

    # Tolerance used for the gradient comparisons; dtypes not listed get 0.
    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        # input[k] is the single scalar used to fill the k-th matrix
        filled22 = wp.matrix(input[0], shape=(2, 2))
        filled33 = wp.matrix(input[1], shape=(3, 3))
        filled44 = wp.matrix(input[2], shape=(4, 4))
        filled55 = wp.matrix(input[3], shape=(5, 5))
        filled32 = wp.matrix(input[4], shape=(3, 2))

        # write 2 * entry for every matrix, one after the other
        slot = 0
        for row in range(2):
            for col in range(2):
                output[slot] = wptype(2) * filled22[row, col]
                slot = slot + 1
        for row in range(3):
            for col in range(3):
                output[slot] = wptype(2) * filled33[row, col]
                slot = slot + 1
        for row in range(4):
            for col in range(4):
                output[slot] = wptype(2) * filled44[row, col]
                slot = slot + 1
        for row in range(5):
            for col in range(5):
                output[slot] = wptype(2) * filled55[row, col]
                slot = slot + 1
        for row in range(3):
            for col in range(2):
                output[slot] = wptype(2) * filled32[row, col]
                slot = slot + 1

    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        # consecutive runs of input supply the components of each matrix
        built22 = wp.matrix(input[0], input[1], input[2], input[3], shape=(2, 2))
        built33 = wp.matrix(
            input[4], input[5], input[6],
            input[7], input[8], input[9],
            input[10], input[11], input[12],
            shape=(3, 3),
        )
        built44 = wp.matrix(
            input[13], input[14], input[15], input[16],
            input[17], input[18], input[19], input[20],
            input[21], input[22], input[23], input[24],
            input[25], input[26], input[27], input[28],
            shape=(4, 4),
        )
        built55 = wp.matrix(
            input[29], input[30], input[31], input[32], input[33],
            input[34], input[35], input[36], input[37], input[38],
            input[39], input[40], input[41], input[42], input[43],
            input[44], input[45], input[46], input[47], input[48],
            input[49], input[50], input[51], input[52], input[53],
            shape=(5, 5),
        )
        built32 = wp.matrix(input[54], input[55], input[56], input[57], input[58], input[59], shape=(3, 2))

        # write 2 * entry for every matrix, one after the other
        slot = 0
        for row in range(2):
            for col in range(2):
                output[slot] = wptype(2) * built22[row, col]
                slot = slot + 1
        for row in range(3):
            for col in range(3):
                output[slot] = wptype(2) * built33[row, col]
                slot = slot + 1
        for row in range(4):
            for col in range(4):
                output[slot] = wptype(2) * built44[row, col]
                slot = slot + 1
        for row in range(5):
            for col in range(5):
                output[slot] = wptype(2) * built55[row, col]
                slot = slot + 1
        for row in range(3):
            for col in range(2):
                output[slot] = wptype(2) * built32[row, col]
                slot = slot + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    n_out = 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2

    # ---- matrices filled from one scalar each ----
    scalars = wp.array(randvals(rng, [5], dtype), requires_grad=True, device=device)
    output = wp.zeros(n_out, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[scalars], outputs=[output], device=device)

    # (slice of the output, entry count) for the matrix filled from scalars[k]
    layout = [
        (slice(None, 4), 2 * 2),
        (slice(4, 13), 3 * 3),
        (slice(13, 29), 4 * 4),
        (slice(29, 54), 5 * 5),
        (slice(54, None), 3 * 2),
    ]
    for k, (span, count) in enumerate(layout):
        assert_np_equal(output.numpy()[span], 2 * np.array([scalars.numpy()[k]] * count), tol=1.0e-6)

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        # first output index lying beyond each of the first four matrices
        block_ends = (4, 13, 29, 54)
        for i in range(len(output)):
            tape = wp.Tape()
            with tape:
                wp.launch(scalar_kernel, dim=1, inputs=[scalars], outputs=[output], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)

            tape.backward(loss=out)

            # only the scalar that filled the matrix owning output[i] gets a gradient of 2
            expected = np.zeros_like(scalars.numpy())
            expected[sum(i >= end for end in block_ends)] = 2

            assert_np_equal(tape.gradients[scalars].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # ---- matrices built from explicit components ----
    components = wp.array(randvals(rng, [n_out], dtype), requires_grad=True, device=device)
    output = wp.zeros(n_out, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[components], outputs=[output], device=device)

    assert_np_equal(output.numpy(), 2 * components.numpy(), tol=1.0e-6)

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(len(output)):
            tape = wp.Tape()
            with tape:
                wp.launch(component_kernel, dim=1, inputs=[components], outputs=[output], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)

            tape.backward(loss=out)

            # output[i] is 2 * components[i], so the gradient is 2 at position i only
            expected = np.zeros_like(components.numpy())
            expected[i] = 2

            assert_np_equal(tape.gradients[components].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()
625
+
626
+
627
def test_identity(test, device, dtype, register_kernels=False):
    """Check ``wp.identity`` for sizes 2 to 5 against ``np.eye``.

    The kernel writes twice every entry of each identity matrix into a flat
    output array; each segment is then compared with ``2 * np.eye(n)``.

    Args:
        test: Test-case object supplied by the harness (not used here).
        device: Device on which the array is allocated and the kernel launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only create the kernel and return.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_identity_mat(
        output: wp.array(dtype=wptype),
    ):
        eye2 = wp.identity(dtype=wptype, n=2)
        eye3 = wp.identity(dtype=wptype, n=3)
        eye4 = wp.identity(dtype=wptype, n=4)
        eye5 = wp.identity(dtype=wptype, n=5)

        # write 2 * entry for every matrix, one after the other
        slot = 0
        for row in range(2):
            for col in range(2):
                output[slot] = wptype(2) * eye2[row, col]
                slot = slot + 1
        for row in range(3):
            for col in range(3):
                output[slot] = wptype(2) * eye3[row, col]
                slot = slot + 1
        for row in range(4):
            for col in range(4):
                output[slot] = wptype(2) * eye4[row, col]
                slot = slot + 1
        for row in range(5):
            for col in range(5):
                output[slot] = wptype(2) * eye5[row, col]
                slot = slot + 1

    id_kernel = getkernel(check_identity_mat, suffix=dtype.__name__)

    if register_kernels:
        return

    output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
    wp.launch(id_kernel, dim=1, inputs=[], outputs=[output], device=device)

    # (matrix size, slice of the flat output holding that matrix)
    segments = [
        (2, slice(None, 4)),
        (3, slice(4, 13)),
        (4, slice(13, 29)),
        (5, slice(29, None)),
    ]
    for n, span in segments:
        assert_np_equal(output.numpy()[span], 2 * np.eye(n), tol=1.0e-6)
667
+
668
+
669
def test_indexing(test, device, dtype, register_kernels=False):
    """Check ``m[i, j]`` reads of square matrices (2x2 to 5x5) and their gradients.

    The kernel copies twice every matrix entry into a flat output array.  The
    values are compared with the NumPy inputs, and for floating-point dtypes
    each output entry is back-propagated to confirm that only the matching
    input entry receives a gradient of 2.

    Args:
        test: Test-case object supplied by the harness (not used here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only create the kernels and return.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_indexing(
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        # outputs are doubled so that there is something to backpropagate
        slot = 0
        for row in range(2):
            for col in range(2):
                outcomponents[slot] = wptype(2) * m2[0][row, col]
                slot = slot + 1

        for row in range(3):
            for col in range(3):
                outcomponents[slot] = wptype(2) * m3[0][row, col]
                slot = slot + 1

        for row in range(4):
            for col in range(4):
                outcomponents[slot] = wptype(2) * m4[0][row, col]
                slot = slot + 1

        for row in range(5):
            for col in range(5):
                outcomponents[slot] = wptype(2) * m5[0][row, col]
                slot = slot + 1

    kernel = getkernel(check_mat_indexing, suffix=dtype.__name__)

    if register_kernels:
        return

    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)

    # (source matrix, slice of the flat output that holds its doubled entries)
    segments = [
        (m2, slice(None, 4)),
        (m3, slice(4, 13)),
        (m4, slice(13, 29)),
        (m5, slice(29, 54)),
    ]
    for source, span in segments:
        assert_np_equal(outcomponents.numpy()[span], 2 * source.numpy().reshape(-1), tol=tol)

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        offset = 0  # flat output index of entry [0, 0] of the current matrix
        for dim, source in [(2, m2), (3, m3), (4, m4), (5, m5)]:
            # np.ndindex walks (i, j) in the same row-major order the kernel writes
            for local, (i, j) in enumerate(np.ndindex(dim, dim)):
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
                    wp.launch(
                        output_select_kernel,
                        dim=1,
                        inputs=[outcomponents, offset + local],
                        outputs=[out],
                        device=device,
                    )
                tape.backward(loss=out)

                # only entry [i, j] of the source matrix feeds the selected output
                expectedresult = np.zeros((dim, dim), dtype=dtype)
                expectedresult[i, j] = 2
                assert_np_equal(tape.gradients[source].numpy()[0], expectedresult)
                tape.zero()
            offset += dim * dim
751
+
752
+
753
def test_equality(test, device, dtype, register_kernels=False):
    """Check matrix equality and inequality inside a kernel for 2x2 to 5x5 matrices.

    For every size the kernel asserts with ``wp.expect_eq`` that two matrices
    built from the same values compare equal, and with ``wp.expect_neq`` that
    two matrices differing only in the sign of one entry do not.

    Args:
        test: Test-case object supplied by the harness (not used here).
        device: Device on which the kernel is launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only create the kernel and return.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    def check_mat_equality():
        # ---- 2x2 ----
        lhs22 = mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0))
        rhs22 = mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0))
        wp.expect_eq(lhs22, rhs22)

        # last entry negated on the left-hand side
        flipped22 = mat22(wptype(1.0), wptype(2.0), wptype(3.0), -wptype(4.0))
        plain22 = mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0))
        wp.expect_neq(flipped22, plain22)

        # ---- 3x3 ----
        lhs33 = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        rhs33 = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        wp.expect_eq(lhs33, rhs33)

        plain33 = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        # fourth entry negated on the right-hand side
        flipped33 = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            -wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        wp.expect_neq(plain33, flipped33)

        # ---- 4x4 ----
        lhs44 = mat44(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        rhs44 = mat44(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        wp.expect_eq(lhs44, rhs44)

        plain44 = mat44(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        # first entry negated on the right-hand side
        flipped44 = mat44(
            -wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        wp.expect_neq(plain44, flipped44)

        # ---- 5x5 ----
        lhs55 = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        rhs55 = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        wp.expect_eq(lhs55, rhs55)

        plain55 = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        # seventeenth entry negated on the right-hand side
        flipped55 = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), -wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        wp.expect_neq(plain55, flipped55)

    kernel = getkernel(check_mat_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    wp.launch(kernel, dim=1, inputs=[], outputs=[], device=device)
1017
+
1018
+
1019
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check scalar * matrix and matrix * scalar for 2x2 to 5x5 matrices.

    The kernel forms every product twice and writes twice each entry into two
    flat arrays, one for scalar-on-the-left and one for scalar-on-the-right.
    Values are compared with NumPy; for floating-point dtypes the gradients
    with respect to both the matrix and the scalar are checked for each entry
    in the first half of both arrays.

    Args:
        test: Test-case object supplied by the harness (not used here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only create the kernels and return.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_scalar_mul(
        s: wp.array(dtype=wptype),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
        outcomponents_rightmul: wp.array(dtype=wptype),
    ):
        # first set of products: scalar on the left ...
        left2 = s[0] * m2[0]
        left3 = s[0] * m3[0]
        left4 = s[0] * m4[0]
        left5 = s[0] * m5[0]

        # ... and scalar on the right
        right2 = m2[0] * s[0]
        right3 = m3[0] * s[0]
        right4 = m4[0] * s[0]
        right5 = m5[0] * s[0]

        # second set, formed with the same expressions
        left2_b = s[0] * m2[0]
        left3_b = s[0] * m3[0]
        left4_b = s[0] * m4[0]
        left5_b = s[0] * m5[0]

        right2_b = m2[0] * s[0]
        right3_b = m3[0] * s[0]
        right4_b = m4[0] * s[0]
        right5_b = m5[0] * s[0]

        # outputs are doubled so that there is something to backpropagate
        slot = 0
        for row in range(2):
            for col in range(2):
                outcomponents[slot] = wptype(2) * left2[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right2[row, col]
                slot = slot + 1

        for row in range(3):
            for col in range(3):
                outcomponents[slot] = wptype(2) * left3[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right3[row, col]
                slot = slot + 1

        for row in range(4):
            for col in range(4):
                outcomponents[slot] = wptype(2) * left4[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right4[row, col]
                slot = slot + 1

        for row in range(5):
            for col in range(5):
                outcomponents[slot] = wptype(2) * left5[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right5[row, col]
                slot = slot + 1

        for row in range(2):
            for col in range(2):
                outcomponents[slot] = wptype(2) * left2_b[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right2_b[row, col]
                slot = slot + 1

        for row in range(3):
            for col in range(3):
                outcomponents[slot] = wptype(2) * left3_b[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right3_b[row, col]
                slot = slot + 1

        for row in range(4):
            for col in range(4):
                outcomponents[slot] = wptype(2) * left4_b[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right4_b[row, col]
                slot = slot + 1

        for row in range(5):
            for col in range(5):
                outcomponents[slot] = wptype(2) * left5_b[row, col]
                outcomponents_rightmul[slot] = wptype(2) * right5_b[row, col]
                slot = slot + 1

    kernel = getkernel(check_mat_scalar_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)

    # entries written per set of products; each output array holds two sets
    per_set = 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5
    outcomponents = wp.zeros(2 * per_set, dtype=wptype, requires_grad=True, device=device)
    outcomponents_rightmul = wp.zeros(2 * per_set, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents, outcomponents_rightmul], device=device)

    sval = s.numpy()[0]

    # (matrix size, matrix array, tolerance for its value comparison)
    value_checks = [
        (2, m2, tol),
        (3, m3, 10 * tol),
        (4, m4, 10 * tol),
        (5, m5, 10 * tol),
    ]
    for base in (0, per_set):
        for result in (outcomponents, outcomponents_rightmul):
            start = base
            for dim, mat, mat_tol in value_checks:
                stop = start + dim * dim
                assert_np_equal(result.numpy()[start:stop], 2 * sval * mat.numpy().reshape(-1), tol=mat_tol)
                start = stop

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        def check_entry_gradient(result, flat, dim, mat, i, j):
            # Backpropagate result[flat] and compare the gradients of mat and s.
            tape = wp.Tape()
            with tape:
                wp.launch(
                    kernel,
                    dim=1,
                    inputs=[s, m2, m3, m4, m5],
                    outputs=[outcomponents, outcomponents_rightmul],
                    device=device,
                )
                wp.launch(output_select_kernel, dim=1, inputs=[result, flat], outputs=[out], device=device)
            tape.backward(loss=out)

            expectedresult = np.zeros((dim, dim), dtype=dtype)
            expectedresult[i, j] = 2 * sval
            assert_np_equal(tape.gradients[mat].numpy()[0], expectedresult, tol=10 * tol)
            assert_np_equal(tape.gradients[s].numpy()[0], 2 * mat.numpy()[0, i, j], tol=10 * tol)
            tape.zero()

        offset = 0  # flat output index of entry [0, 0] of the current matrix
        for dim, mat in [(2, m2), (3, m3), (4, m4), (5, m5)]:
            # np.ndindex walks (i, j) in the same row-major order the kernel writes
            for local, (i, j) in enumerate(np.ndindex(dim, dim)):
                # left-multiplication gradient first, then right-multiplication
                check_entry_gradient(outcomponents, offset + local, dim, mat, i, j)
                check_entry_gradient(outcomponents_rightmul, offset + local, dim, mat, i, j)
            offset += dim * dim
1204
+
1205
+
1206
def test_matvec_multiplication(test, device, dtype, register_kernels=False):
    """Check matrix-vector products written with ``*`` and with ``@``.

    Square matrices from 2x2 to 5x5 and one 3x2 matrix are multiplied by
    matching vectors.  The kernel writes twice every result component, first
    for the ``*`` products and then for the ``@`` products.  Values are
    compared with ``np.matmul``; for floating-point dtypes the gradients with
    respect to the vector and the matrix are checked for the ``*`` outputs.

    Args:
        test: Test-case object supplied by the harness (not used here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only create the kernels and return.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_vec_mul(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v32: wp.array(dtype=vec2),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        m32: wp.array(dtype=mat32),
        outcomponents: wp.array(dtype=wptype),
    ):
        # products written with the * operator
        star2 = m2[0] * v2[0]
        star3 = m3[0] * v3[0]
        star4 = m4[0] * v4[0]
        star5 = m5[0] * v5[0]
        star32 = m32[0] * v32[0]
        # the same products written with the @ operator
        at2 = m2[0] @ v2[0]
        at3 = m3[0] @ v3[0]
        at4 = m4[0] @ v4[0]
        at5 = m5[0] @ v5[0]
        at32 = m32[0] @ v32[0]

        slot = 0

        # outputs are doubled so that there is something to backpropagate
        for k in range(2):
            outcomponents[slot] = wptype(2) * star2[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * star3[k]
            slot = slot + 1

        for k in range(4):
            outcomponents[slot] = wptype(2) * star4[k]
            slot = slot + 1

        for k in range(5):
            outcomponents[slot] = wptype(2) * star5[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * star32[k]
            slot = slot + 1

        for k in range(2):
            outcomponents[slot] = wptype(2) * at2[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * at3[k]
            slot = slot + 1

        for k in range(4):
            outcomponents[slot] = wptype(2) * at4[k]
            slot = slot + 1

        for k in range(5):
            outcomponents[slot] = wptype(2) * at5[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * at32[k]
            slot = slot + 1

    kernel = getkernel(check_mat_vec_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    v2 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
    v32 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)

    # components written per operator form; the output holds both forms
    per_form = 2 + 3 + 4 + 5 + 3
    outcomponents = wp.zeros(2 * per_form, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)

    # (result length, matrix, vector, tolerance for the value comparison)
    value_checks = [
        (2, m2, v2, tol),
        (3, m3, v3, tol),
        (4, m4, v4, 5 * tol),
        (5, m5, v5, 5 * tol),
        (3, m32, v32, 5 * tol),
    ]
    start = 0
    for _form in range(2):  # first the * results, then the @ results
        for rows, mat, vec, case_tol in value_checks:
            stop = start + rows
            assert_np_equal(
                outcomponents.numpy()[start:stop],
                2 * np.matmul(mat.numpy()[0], vec.numpy()[0]),
                tol=case_tol,
            )
            start = stop

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        offset = 0  # flat output index of component 0 of the current product
        for dim, invec, inmat in [(2, v2, m2), (3, v3, m3), (4, v4, m4), (5, v5, m5), (3, v32, m32)]:
            for i in range(dim):
                tape = wp.Tape()
                with tape:
                    wp.launch(
                        kernel,
                        dim=1,
                        inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
                        outputs=[outcomponents],
                        device=device,
                    )
                    wp.launch(
                        output_select_kernel,
                        dim=1,
                        inputs=[outcomponents, offset + i],
                        outputs=[out],
                        device=device,
                    )
                tape.backward(loss=out)

                # component i is 2 * (row i of the matrix) . vector
                assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, i, :], tol=2 * tol)
                expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
                expectedresult[i, :] = 2 * invec.numpy()[0]
                assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)

                tape.zero()
            offset += dim
1351
+
1352
+
1353
+ def test_vecmat_multiplication(test, device, dtype, register_kernels=False):
1354
+ rng = np.random.default_rng(123)
1355
+
1356
+ tol = {
1357
+ np.float16: 2.0e-2,
1358
+ np.float32: 5.0e-6,
1359
+ np.float64: 1.0e-8,
1360
+ }.get(dtype, 0)
1361
+
1362
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1363
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1364
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1365
+ mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)
1366
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1367
+ mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1368
+
1369
+ vec2 = wp.types.vector(length=2, dtype=wptype)
1370
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1371
+ vec4 = wp.types.vector(length=4, dtype=wptype)
1372
+ vec5 = wp.types.vector(length=5, dtype=wptype)
1373
+
1374
+ output_select_kernel = get_select_kernel(wptype)
1375
+
1376
+ def check_vec_mat_mul(
1377
+ v2: wp.array(dtype=vec2),
1378
+ v3: wp.array(dtype=vec3),
1379
+ v4: wp.array(dtype=vec4),
1380
+ v5: wp.array(dtype=vec5),
1381
+ v32: wp.array(dtype=vec2),
1382
+ m2: wp.array(dtype=mat22),
1383
+ m3: wp.array(dtype=mat33),
1384
+ m4: wp.array(dtype=mat44),
1385
+ m5: wp.array(dtype=mat55),
1386
+ m23: wp.array(dtype=mat23),
1387
+ outcomponents: wp.array(dtype=wptype),
1388
+ ):
1389
+ v2result = v2[0] * m2[0]
1390
+ v3result = v3[0] * m3[0]
1391
+ v4result = v4[0] * m4[0]
1392
+ v5result = v5[0] * m5[0]
1393
+ v32result = v32[0] * m23[0]
1394
+ v2result_2 = v2[0] @ m2[0]
1395
+ v3result_2 = v3[0] @ m3[0]
1396
+ v4result_2 = v4[0] @ m4[0]
1397
+ v5result_2 = v5[0] @ m5[0]
1398
+ v32result_2 = v32[0] @ m23[0]
1399
+
1400
+ idx = 0
1401
+
1402
+ # multiply outputs by 2 so we've got something to backpropagate:
1403
+ for i in range(2):
1404
+ outcomponents[idx] = wptype(2) * v2result[i]
1405
+ idx = idx + 1
1406
+
1407
+ for i in range(3):
1408
+ outcomponents[idx] = wptype(2) * v3result[i]
1409
+ idx = idx + 1
1410
+
1411
+ for i in range(4):
1412
+ outcomponents[idx] = wptype(2) * v4result[i]
1413
+ idx = idx + 1
1414
+
1415
+ for i in range(5):
1416
+ outcomponents[idx] = wptype(2) * v5result[i]
1417
+ idx = idx + 1
1418
+
1419
+ for i in range(3):
1420
+ outcomponents[idx] = wptype(2) * v32result[i]
1421
+ idx = idx + 1
1422
+
1423
+ for i in range(2):
1424
+ outcomponents[idx] = wptype(2) * v2result_2[i]
1425
+ idx = idx + 1
1426
+
1427
+ for i in range(3):
1428
+ outcomponents[idx] = wptype(2) * v3result_2[i]
1429
+ idx = idx + 1
1430
+
1431
+ for i in range(4):
1432
+ outcomponents[idx] = wptype(2) * v4result_2[i]
1433
+ idx = idx + 1
1434
+
1435
+ for i in range(5):
1436
+ outcomponents[idx] = wptype(2) * v5result_2[i]
1437
+ idx = idx + 1
1438
+
1439
+ for i in range(3):
1440
+ outcomponents[idx] = wptype(2) * v32result_2[i]
1441
+ idx = idx + 1
1442
+
1443
+ kernel = getkernel(check_vec_mat_mul, suffix=dtype.__name__)
1444
+
1445
+ if register_kernels:
1446
+ return
1447
+
1448
+ v2 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1449
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1450
+ v4 = wp.array(randvals(rng, [1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1451
+ v5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1452
+ v32 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1453
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1454
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1455
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1456
+ m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1457
+ m23 = wp.array(randvals(rng, [1, 2, 3], dtype), dtype=mat23, requires_grad=True, device=device)
1458
+ outcomponents = wp.zeros(2 * (2 + 3 + 4 + 5 + 3), dtype=wptype, requires_grad=True, device=device)
1459
+
1460
+ wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m23], outputs=[outcomponents], device=device)
1461
+
1462
+ assert_np_equal(outcomponents.numpy()[:2], 2 * np.matmul(v2.numpy()[0], m2.numpy()[0]), tol=tol)
1463
+ assert_np_equal(outcomponents.numpy()[2:5], 2 * np.matmul(v3.numpy()[0], m3.numpy()[0]), tol=tol)
1464
+ assert_np_equal(outcomponents.numpy()[5:9], 2 * np.matmul(v4.numpy()[0], m4.numpy()[0]), tol=5 * tol)
1465
+ assert_np_equal(outcomponents.numpy()[9:14], 2 * np.matmul(v5.numpy()[0], m5.numpy()[0]), tol=5 * tol)
1466
+ assert_np_equal(outcomponents.numpy()[14:17], 2 * np.matmul(v32.numpy()[0], m23.numpy()[0]), tol=5 * tol)
1467
+ assert_np_equal(outcomponents.numpy()[17:19], 2 * np.matmul(v2.numpy()[0], m2.numpy()[0]), tol=tol)
1468
+ assert_np_equal(outcomponents.numpy()[19:22], 2 * np.matmul(v3.numpy()[0], m3.numpy()[0]), tol=tol)
1469
+ assert_np_equal(outcomponents.numpy()[22:26], 2 * np.matmul(v4.numpy()[0], m4.numpy()[0]), tol=5 * tol)
1470
+ assert_np_equal(outcomponents.numpy()[26:31], 2 * np.matmul(v5.numpy()[0], m5.numpy()[0]), tol=5 * tol)
1471
+ assert_np_equal(outcomponents.numpy()[31:34], 2 * np.matmul(v32.numpy()[0], m23.numpy()[0]), tol=5 * tol)
1472
+
1473
+ if dtype in np_float_types:
1474
+ idx = 0
1475
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1476
+ for dim, inmat, invec in [(2, m2, v2), (3, m3, v3), (4, m4, v4), (5, m5, v5), (3, m23, v32)]:
1477
+ for i in range(dim):
1478
+ tape = wp.Tape()
1479
+ with tape:
1480
+ wp.launch(
1481
+ kernel,
1482
+ dim=1,
1483
+ inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m23],
1484
+ outputs=[outcomponents],
1485
+ device=device,
1486
+ )
1487
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1488
+ tape.backward(loss=out)
1489
+
1490
+ assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, :, i], tol=2 * tol)
1491
+ expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
1492
+ expectedresult[:, i] = 2 * invec.numpy()[0]
1493
+ assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)
1494
+
1495
+ tape.zero()
1496
+
1497
+ idx = idx + 1
1498
+
1499
+
1500
def test_matmat_multiplication(test, device, dtype, register_kernels=False):
    """Check matrix-matrix products written with ``*`` and with ``@``.

    A single-thread kernel multiplies 2x2, 3x3, 4x4 and 5x5 pairs as well as
    a 3x2 by 2x2 and a 3x3 by 3x2 pair, once per operator, and stores twice
    every product component in a flat output array.  The output is compared
    against ``np.matmul``.  For ``dtype`` in ``np_float_types`` the gradient
    of every component of the ``*`` products is also checked via ``wp.Tape``.

    Args:
        test: test-case instance (not used by this function).
        device: device used for array allocation and kernel launches.
        dtype: NumPy scalar type of the matrix entries.
        register_kernels: when true, return as soon as the kernel is built.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_mat_mul(
        a2: wp.array(dtype=mat22),
        a3: wp.array(dtype=mat33),
        a4: wp.array(dtype=mat44),
        a5: wp.array(dtype=mat55),
        a32: wp.array(dtype=mat32),
        b2: wp.array(dtype=mat22),
        b3: wp.array(dtype=mat33),
        b4: wp.array(dtype=mat44),
        b5: wp.array(dtype=mat55),
        b32: wp.array(dtype=mat32),
        outcomponents: wp.array(dtype=wptype),
    ):
        # the b* arguments are the left factors, the a* arguments the right factors
        star2 = b2[0] * a2[0]
        star3 = b3[0] * a3[0]
        star4 = b4[0] * a4[0]
        star5 = b5[0] * a5[0]
        star32 = b32[0] * a2[0]
        star33x32 = b3[0] * a32[0]

        at2 = b2[0] @ a2[0]
        at3 = b3[0] @ a3[0]
        at4 = b4[0] @ a4[0]
        at5 = b5[0] @ a5[0]
        at32 = b32[0] @ a2[0]
        at33x32 = b3[0] @ a32[0]

        # every stored component is doubled so that backpropagation has a non-trivial factor
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * star2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * star3[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * star4[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * star5[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(2):
                outcomponents[idx] = wptype(2) * star32[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(2):
                outcomponents[idx] = wptype(2) * star33x32[i, j]
                idx = idx + 1

        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * at2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * at3[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * at4[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * at5[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(2):
                outcomponents[idx] = wptype(2) * at32[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(2):
                outcomponents[idx] = wptype(2) * at33x32[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_mat_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_matrix(rows, cols, mat_type):
        # one random matrix wrapped in a length-1 Warp array
        values = randvals(rng, [1, rows, cols], dtype)
        return wp.array(values, dtype=mat_type, requires_grad=True, device=device)

    # right-hand factors first, then left-hand factors (same draw order from rng as the kernel signature)
    rhs2 = random_matrix(2, 2, mat22)
    rhs3 = random_matrix(3, 3, mat33)
    rhs4 = random_matrix(4, 4, mat44)
    rhs5 = random_matrix(5, 5, mat55)
    rhs32 = random_matrix(3, 2, mat32)
    lhs2 = random_matrix(2, 2, mat22)
    lhs3 = random_matrix(3, 3, mat33)
    lhs4 = random_matrix(4, 4, mat44)
    lhs5 = random_matrix(5, 5, mat55)
    lhs32 = random_matrix(3, 2, mat32)

    outcomponents = wp.zeros(
        2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2 + 3 * 2), dtype=wptype, requires_grad=True, device=device
    )

    def run_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[rhs2, rhs3, rhs4, rhs5, rhs32, lhs2, lhs3, lhs4, lhs5, lhs32],
            outputs=[outcomponents],
            device=device,
        )

    run_forward()

    # (left factor, right factor, tolerance multiplier), in the order the kernel writes the products
    products = [
        (lhs2, rhs2, 1),
        (lhs3, rhs3, 1),
        (lhs4, rhs4, 2),
        (lhs5, rhs5, 10),
        (lhs32, rhs2, 5),
        (lhs3, rhs32, 5),
    ]

    # the output holds the "*" results followed by the "@" results, laid out identically
    offset = 0
    for _ in range(2):
        for lhs, rhs, scale in products:
            reference = 2 * np.matmul(lhs.numpy()[0], rhs.numpy()[0])
            stop = offset + reference.size
            assert_np_equal(outcomponents.numpy()[offset:stop], reference, tol=scale * tol)
            offset = stop

    if dtype not in np_float_types:
        return

    # backpropagate from one "*" output component at a time
    component = 0
    selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for lhs, rhs, _ in products:
        n_rows = lhs.dtype._shape_[0]
        n_cols = rhs.dtype._shape_[1]
        for i in range(n_rows):
            for j in range(n_cols):
                tape = wp.Tape()
                with tape:
                    run_forward()
                    wp.launch(
                        output_select_kernel,
                        dim=1,
                        inputs=[outcomponents, component],
                        outputs=[selected],
                        device=device,
                    )
                tape.backward(loss=selected)

                # component (i, j) is 2 * dot(lhs[i, :], rhs[:, j])
                rhs_grad = np.zeros(rhs.dtype._shape_, dtype=dtype)
                rhs_grad[:, j] = 2 * lhs.numpy()[0, i, :]
                assert_np_equal(tape.gradients[rhs].numpy()[0], rhs_grad, tol=10 * tol)

                lhs_grad = np.zeros(lhs.dtype._shape_, dtype=dtype)
                lhs_grad[i, :] = 2 * rhs.numpy()[0, :, j]
                assert_np_equal(tape.gradients[lhs].numpy()[0], lhs_grad, tol=10 * tol)

                tape.zero()
                component = component + 1
1671
+
1672
+
1673
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_mul`` (component-wise product) on 2x2 .. 5x5 matrices.

    The kernel stores ``2 * cw_mul(v, s)`` for each matrix size in a flat
    output array, which is compared against the NumPy element-wise product.
    For ``dtype`` in ``np_float_types`` the gradient of every output
    component with respect to both operands is checked via ``wp.Tape``.

    Args:
        test: test-case instance (not used by this function).
        device: device used for array allocation and kernel launches.
        dtype: NumPy scalar type of the matrix entries.
        register_kernels: when true, return as soon as the kernel is built.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_cw_mul(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        # results are doubled so that backpropagation has a non-trivial factor
        prod2 = wptype(2) * wp.cw_mul(v2[0], s2[0])
        prod3 = wptype(2) * wp.cw_mul(v3[0], s3[0])
        prod4 = wptype(2) * wp.cw_mul(v4[0], s4[0])
        prod5 = wptype(2) * wp.cw_mul(v5[0], s5[0])

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = prod2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = prod3[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = prod4[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = prod5[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_square(n, mat_type):
        # one random n x n matrix wrapped in a length-1 Warp array
        values = randvals(rng, [1, n, n], dtype)
        return wp.array(values, dtype=mat_type, requires_grad=True, device=device)

    s2 = random_square(2, mat22)
    s3 = random_square(3, mat33)
    s4 = random_square(4, mat44)
    s5 = random_square(5, mat55)
    v2 = random_square(2, mat22)
    v3 = random_square(3, mat33)
    v4 = random_square(4, mat44)
    v5 = random_square(5, mat55)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    def run_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    run_forward()

    # operand pairs in the order in which the kernel writes their products
    operand_pairs = [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]

    offset = 0
    for _, first, second in operand_pairs:
        reference = 2 * (second.numpy() * first.numpy()).reshape(-1)
        stop = offset + reference.size
        assert_np_equal(outcomponents.numpy()[offset:stop], reference, tol=50 * tol)
        offset = stop

    if dtype not in np_float_types:
        return

    # backpropagate from one output component at a time
    component = 0
    selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for dim, first, second in operand_pairs:
        for i in range(dim):
            for j in range(dim):
                tape = wp.Tape()
                with tape:
                    run_forward()
                    wp.launch(
                        output_select_kernel,
                        dim=1,
                        inputs=[outcomponents, component],
                        outputs=[selected],
                        device=device,
                    )
                tape.backward(loss=selected)

                # only entry (i, j) of either gradient is non-zero: twice the other operand's entry
                grad = np.zeros((dim, dim), dtype=dtype)
                grad[i, j] = 2 * first.numpy()[0][i, j]
                assert_np_equal(tape.gradients[second].numpy()[0], grad, tol=5 * tol)
                grad[i, j] = 2 * second.numpy()[0][i, j]
                assert_np_equal(tape.gradients[first].numpy()[0], grad, tol=5 * tol)
                tape.zero()

                component = component + 1
1801
+
1802
+
1803
def test_cw_division(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_div`` (component-wise quotient) on 2x2 .. 5x5 matrices.

    The kernel stores ``2 * cw_div(v, s)`` for each matrix size in a flat
    output array.  It is compared against NumPy true division when ``dtype``
    is in ``np_float_types`` and against NumPy floor division otherwise.
    For floating point types the gradient of every output component with
    respect to numerator and denominator is checked via ``wp.Tape``.

    Args:
        test: test-case instance (not used by this function).
        device: device used for array allocation and kernel launches.
        dtype: NumPy scalar type of the matrix entries.
        register_kernels: when true, return as soon as the kernel is built.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_cw_div(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        # results are doubled so that backpropagation has a non-trivial factor
        quot2 = wptype(2) * wp.cw_div(v2[0], s2[0])
        quot3 = wptype(2) * wp.cw_div(v3[0], s3[0])
        quot4 = wptype(2) * wp.cw_div(v4[0], s4[0])
        quot5 = wptype(2) * wp.cw_div(v5[0], s5[0])

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = quot2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = quot3[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = quot4[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = quot5[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_cw_div, suffix=dtype.__name__)

    if register_kernels:
        return

    square_types = [(2, mat22), (3, mat33), (4, mat44), (5, mat55)]

    # Denominators: entries of small magnitude are replaced by 1 to prevent
    # division by zero, and overflow when float16 is being tested.
    denominators = []
    for n, mat_type in square_types:
        raw = randvals(rng, [1, n, n], dtype)
        raw[np.abs(raw) < 1.0e-2] = 1
        denominators.append(raw)
    s2, s3, s4, s5 = [
        wp.array(raw, dtype=mat_type, requires_grad=True, device=device)
        for raw, (_, mat_type) in zip(denominators, square_types)
    ]

    # numerators are drawn after all denominators
    v2, v3, v4, v5 = [
        wp.array(randvals(rng, [1, n, n], dtype), dtype=mat_type, requires_grad=True, device=device)
        for n, mat_type in square_types
    ]
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    def run_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    run_forward()

    is_float = dtype in np_float_types

    def reference_quotient(numerator, denominator):
        # true division for floating point types, floor division for the others
        if is_float:
            return 2 * (numerator.numpy() / denominator.numpy()).reshape(-1)
        return 2 * (numerator.numpy() // denominator.numpy()).reshape(-1)

    # operand pairs in the order in which the kernel writes their quotients
    operand_pairs = [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]

    offset = 0
    for dim, den, num in operand_pairs:
        stop = offset + dim * dim
        assert_np_equal(outcomponents.numpy()[offset:stop], reference_quotient(num, den), tol=50 * tol)
        offset = stop

    if not is_float:
        return

    # backpropagate from one output component at a time
    component = 0
    selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for dim, den, num in operand_pairs:
        for i in range(dim):
            for j in range(dim):
                tape = wp.Tape()
                with tape:
                    run_forward()
                    wp.launch(
                        output_select_kernel,
                        dim=1,
                        inputs=[outcomponents, component],
                        outputs=[selected],
                        device=device,
                    )
                tape.backward(loss=selected)

                # with y = v / s:  dy/dv = 1 / s  and  dy/ds = -v / s^2  (both doubled by the kernel)
                grad = np.zeros((dim, dim), dtype=dtype)
                grad[i, j] = 2.0 / (den.numpy()[0, i, j])
                assert_np_equal(tape.gradients[num].numpy()[0], grad, tol=50 * tol)
                grad[i, j] = -2.0 * num.numpy()[0, i, j] / (den.numpy()[0, i, j] ** 2)
                # the tolerance for the denominator gradient scales with the magnitude of the output
                assert_np_equal(
                    tape.gradients[den].numpy()[0], grad, tol=abs(outcomponents.numpy()[component]) * 50 * tol
                )
                tape.zero()

                component = component + 1
1958
+
1959
+
1960
def test_outer_product(test, device, dtype, register_kernels=False):
    """Check ``wp.outer`` for vector pairs of length 2, 3, 4, 5 and for a 2-by-5 pair.

    The kernel stores ``2 * outer(s, v)`` for each pair in a flat output
    array, which is compared against the NumPy broadcast product.  For
    ``dtype`` in ``np_float_types`` the gradient of every output component
    with respect to both vectors is checked via ``wp.Tape``.

    Args:
        test: test-case instance (not used by this function).
        device: device used for array allocation and kernel launches.
        dtype: NumPy scalar type of the vector entries.
        register_kernels: when true, return as soon as the kernel is built.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_outer_product(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        outcomponents: wp.array(dtype=wptype),
    ):
        # results are doubled so that backpropagation has a non-trivial factor
        outer22 = wptype(2) * wp.outer(s2[0], v2[0])
        outer33 = wptype(2) * wp.outer(s3[0], v3[0])
        outer44 = wptype(2) * wp.outer(s4[0], v4[0])
        outer55 = wptype(2) * wp.outer(s5[0], v5[0])
        outer25 = wptype(2) * wp.outer(s2[0], v5[0])

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = outer22[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = outer33[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = outer44[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = outer55[i, j]
                idx = idx + 1

        for i in range(2):
            for j in range(5):
                outcomponents[idx] = outer25[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_outer_product, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vector(n, vec_type):
        # one random length-n vector wrapped in a length-1 Warp array
        values = randvals(rng, [1, n], dtype)
        return wp.array(values, dtype=vec_type, requires_grad=True, device=device)

    s2 = random_vector(2, vec2)
    s3 = random_vector(3, vec3)
    s4 = random_vector(4, vec4)
    s5 = random_vector(5, vec5)
    v2 = random_vector(2, vec2)
    v3 = random_vector(3, vec3)
    v4 = random_vector(4, vec4)
    v5 = random_vector(5, vec5)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 5, dtype=wptype, requires_grad=True, device=device)

    def run_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    run_forward()

    # (column vector, row vector, tolerance multiplier), in the order the kernel writes the products
    vector_pairs = [(s2, v2, 1), (s3, v3, 10), (s4, v4, 10), (s5, v5, 10), (s2, v5, 10)]

    offset = 0
    for col_vec, row_vec, scale in vector_pairs:
        reference = 2 * col_vec.numpy()[0, :, None] * row_vec.numpy()[0, None, :]
        stop = offset + reference.size
        assert_np_equal(outcomponents.numpy()[offset:stop], reference, tol=scale * tol)
        offset = stop

    if dtype not in np_float_types:
        return

    # backpropagate from one output component at a time
    component = 0
    selected = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for col_vec, row_vec, _ in vector_pairs:
        n_rows = col_vec.dtype._length_
        n_cols = row_vec.dtype._length_
        for i in range(n_rows):
            for j in range(n_cols):
                tape = wp.Tape()
                with tape:
                    run_forward()
                    wp.launch(
                        output_select_kernel,
                        dim=1,
                        inputs=[outcomponents, component],
                        outputs=[selected],
                        device=device,
                    )
                tape.backward(loss=selected)

                # The selected component is 2 * s_i * v_j: its gradient with respect to s is
                # non-zero only at entry i, and with respect to v only at entry j.
                col_grad = np.zeros((n_rows), dtype=dtype)
                col_grad[i] = 2 * row_vec.numpy()[0, j]
                assert_np_equal(tape.gradients[col_vec].numpy()[0], col_grad, tol=10 * tol)

                row_grad = np.zeros((n_cols), dtype=dtype)
                row_grad[j] = 2 * col_vec.numpy()[0, i]
                assert_np_equal(tape.gradients[row_vec].numpy()[0], row_grad, tol=10 * tol)
                tape.zero()

                component = component + 1
2088
+
2089
+
2090
def test_transpose(test, device, dtype, register_kernels=False):
    """Check ``wp.transpose`` on 2x2, 3x3, 4x4, 5x5 and 3x2 matrices.

    The kernel stores ``2 * transpose(m)`` for each input in a flat output
    array, which is compared against the NumPy transpose.  For ``dtype`` in
    ``np_float_types`` the gradient of every output component of the square
    inputs is checked via ``wp.Tape``.

    Args:
        test: test-case instance (not used by this function).
        device: device used for array allocation and kernel launches.
        dtype: NumPy scalar type of the matrix entries.
        register_kernels: when true, return as soon as the kernel is built.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_transpose(
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        m32: wp.array(dtype=mat32),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate.
        # The locals are named after the shape of the transposed value and no
        # longer reuse the name of the enclosing 3x2 matrix type.
        t22 = wptype(2) * wp.transpose(m2[0])
        t33 = wptype(2) * wp.transpose(m3[0])
        t44 = wptype(2) * wp.transpose(m4[0])
        t55 = wptype(2) * wp.transpose(m5[0])
        t23 = wptype(2) * wp.transpose(m32[0])

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = t22[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = t33[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = t44[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = t55[i, j]
                idx = idx + 1

        # the transpose of the 3x2 input is 2x3
        for i in range(2):
            for j in range(3):
                outcomponents[idx] = t23[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_transpose, suffix=dtype.__name__)

    if register_kernels:
        return

    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 3, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)

    assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy()[0].T.reshape(-1), tol=tol)
    assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy()[0].T.reshape(-1), tol=tol)
    assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy()[0].T.reshape(-1), tol=tol)
    assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy()[0].T.reshape(-1), tol=tol)
    assert_np_equal(outcomponents.numpy()[54:], 2 * m32.numpy()[0].T.reshape(-1), tol=tol)

    if dtype in np_float_types:
        idx = 0
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        # NOTE(review): only the square inputs are differentiated here; the
        # gradient with respect to the non-square m32 is not checked.
        for mat in [m2, m3, m4, m5]:
            rows, cols = mat.dtype._shape_
            # The kernel writes the transposed (cols x rows) result in row-major
            # order, so walk the output rows first, then the output columns.
            for out_row in range(cols):
                for out_col in range(rows):
                    tape = wp.Tape()
                    with tape:
                        wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
                        wp.launch(
                            output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                        )
                    tape.backward(loss=out)

                    # Output entry (out_row, out_col) equals 2 * mat[out_col, out_row],
                    # and the gradient has the shape of the input matrix.
                    expectedresult = np.zeros((rows, cols), dtype=dtype)
                    expectedresult[out_col, out_row] = 2
                    assert_np_equal(tape.gradients[mat].numpy()[0], expectedresult)
                    tape.zero()
                    idx = idx + 1
2187
+
2188
+
2189
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Exercise matrix / scalar division on 2x2, 3x3, 4x4 and 5x5 matrices.

    A kernel divides one matrix of each size by the same scalar and stores
    twice every resulting component in a flat output array. The forward values
    are compared against NumPy (true division for float dtypes, floor division
    otherwise). For float dtypes each output component is additionally
    back-propagated on its own, and the gradients with respect to the source
    matrix and the scalar are checked.

    Args:
        test: Unused by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, return right after the kernel is built.
    """
    rng = np.random.default_rng(123)

    # dtypes that are not listed here are compared exactly (tolerance 0)
    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_scalar_div(
        s: wp.array(dtype=wptype),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        m2result = m2[0] / s[0]
        m3result = m3[0] / s[0]
        m4result = m4[0] / s[0]
        m5result = m5[0] / s[0]

        # every component is doubled so the backward pass has a non-trivial factor;
        # the four matrices are packed one after another into outcomponents
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * m2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * m4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * m5result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_scalar_div, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)

    sval = s.numpy()[0]
    is_float = dtype in np_float_types

    # (slice of outcomponents, source matrix, tolerance) for each matrix size
    forward_checks = [
        (slice(0, 4), m2, tol),
        (slice(4, 13), m3, 10 * tol),
        (slice(13, 29), m4, 10 * tol),
        (slice(29, 54), m5, 10 * tol),
    ]
    for window, mat, check_tol in forward_checks:
        flat = mat.numpy().reshape(-1)
        if is_float:
            expected = 2 * flat / sval
        else:
            expected = 2 * (flat // sval)
        assert_np_equal(outcomponents.numpy()[window], expected, tol=check_tol)

    # gradients are only checked for floating point dtypes
    if not is_float:
        return

    idx = 0
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for dim, mat in [(2, m2), (3, m3), (4, m4), (5, m5)]:
        for i in range(dim):
            for j in range(dim):
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
                    wp.launch(
                        output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                    )
                tape.backward(loss=out)

                # d(2 * m[i, j] / s) / dm is 2 / s at (i, j) and zero elsewhere
                expected_grad = np.zeros((dim, dim), dtype=dtype)
                expected_grad[i, j] = 2.0 / sval
                assert_np_equal(tape.gradients[mat].numpy()[0], expected_grad, tol=10 * tol)

                # d(2 * m[i, j] / s) / ds is -2 * m[i, j] / s^2
                assert_np_equal(
                    tape.gradients[s].numpy()[0], -2 * mat.numpy()[0, i, j] / (sval * sval), tol=10 * tol
                )
                tape.zero()

                idx += 1
2289
+
2290
+
2291
def test_addition(test, device, dtype, register_kernels=False):
    """Exercise matrix + matrix addition on 2x2, 3x3, 4x4 and 5x5 matrices.

    A kernel adds two matrices of each size and stores twice every resulting
    component in a flat output array. The forward values are compared against
    NumPy. For float dtypes each output component is also back-propagated on
    its own and the gradients with respect to both operands are checked.

    Args:
        test: Unused by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, return right after the kernel is built.
    """
    rng = np.random.default_rng(123)

    # dtypes that are not listed here are compared exactly (tolerance 0)
    tolerances = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_add(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        v2result = v2[0] + s2[0]
        v3result = v3[0] + s3[0]
        v4result = v4[0] + s4[0]
        v5result = v5[0] + s5[0]

        # every component is doubled so the backward pass has a non-trivial factor;
        # the four sums are packed one after another into outcomponents
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * v2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * v3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * v4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * v5result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_add, suffix=dtype.__name__)

    if register_kernels:
        return

    s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(
        kernel,
        dim=1,
        inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
        outputs=[outcomponents],
        device=device,
    )

    # (slice of outcomponents, left operand, right operand) for each matrix size
    forward_checks = [
        (slice(0, 4), v2, s2),
        (slice(4, 13), v3, s3),
        (slice(13, 29), v4, s4),
        (slice(29, 54), v5, s5),
    ]
    for window, lhs, rhs in forward_checks:
        assert_np_equal(outcomponents.numpy()[window], 2 * (lhs.numpy() + rhs.numpy()).reshape(-1), tol=tol)

    # gradients are only checked for floating point dtypes
    if dtype not in np_float_types:
        return

    idx = 0
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
        for i in range(dim):
            for j in range(dim):
                tape = wp.Tape()
                with tape:
                    wp.launch(
                        kernel,
                        dim=1,
                        inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
                        outputs=[outcomponents],
                        device=device,
                    )
                    wp.launch(
                        output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                    )
                tape.backward(loss=out)

                # both operands receive the same gradient: 2 at (i, j), zero elsewhere
                expected_grad = np.zeros((dim, dim), dtype=dtype)
                expected_grad[i, j] = 2
                assert_np_equal(tape.gradients[in2].numpy()[0], expected_grad, tol=10 * tol)
                assert_np_equal(tape.gradients[in1].numpy()[0], expected_grad, tol=10 * tol)
                tape.zero()

                idx += 1
2419
+
2420
+
2421
def test_ddot(test, device, dtype, register_kernels=False):
    """Exercise ``wp.ddot`` on pairs of 2x2, 3x3, 4x4 and 5x5 matrices.

    A kernel writes twice the ``wp.ddot`` of each matrix pair into its own
    one-element output array. The forward values are compared against the sum
    of the element-wise product computed with NumPy. For float dtypes each
    output is back-propagated and the gradients with respect to both operands
    are checked.

    Args:
        test: Unused by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, return right after the kernel is built.
    """
    rng = np.random.default_rng(123)

    # dtypes that are not listed here are compared exactly (tolerance 0)
    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    def check_mat_dot(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        # results are doubled so the backward pass has a non-trivial factor
        dot2[0] = wptype(2) * wp.ddot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.ddot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.ddot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.ddot(v5[0], s5[0])

    kernel = getkernel(check_mat_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    dot2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    dot3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    dot4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    dot5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[dot2, dot3, dot4, dot5],
            device=device,
        )

    # (output, s operand, v operand, forward tolerance multiplier) per matrix size
    cases = [
        (dot2, s2, v2, 10),
        (dot3, s3, v3, 10),
        (dot4, s4, v4, 50),
        (dot5, s5, v5, 200),
    ]

    for dot, s_arr, v_arr, tol_scale in cases:
        assert_np_equal(dot.numpy()[0], 2 * (v_arr.numpy() * s_arr.numpy()).sum(), tol=tol_scale * tol)

    # gradients are only checked for floating point dtypes
    if dtype in np_float_types:
        for dot, s_arr, v_arr, _ in cases:
            tape.backward(loss=dot)

            # d(2 * ddot(v, s)) / ds == 2 * v
            sgrads = tape.gradients[s_arr].numpy()[0]
            assert_np_equal(sgrads, 2.0 * v_arr.numpy()[0], tol=10 * tol)

            # d(2 * ddot(v, s)) / dv == 2 * s
            vgrads = tape.gradients[v_arr].numpy()[0]
            assert_np_equal(vgrads, 2.0 * s_arr.numpy()[0], tol=10 * tol)

            tape.zero()
2542
+
2543
+
2544
def test_trace(test, device, dtype, register_kernels=False):
    """Exercise ``wp.trace`` on 2x2, 3x3, 4x4 and 5x5 matrices.

    A kernel writes twice the trace of each matrix into its own one-element
    output array. The forward values are compared against ``np.trace``. For
    float dtypes each output is back-propagated and the gradient with respect
    to the matrix is expected to be twice the identity.

    Args:
        test: Unused by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, return right after the kernel is built.
    """
    rng = np.random.default_rng(123)

    # dtypes that are not listed here are compared exactly (tolerance 0)
    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    def check_mat_trace(
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        tr2: wp.array(dtype=wptype),
        tr3: wp.array(dtype=wptype),
        tr4: wp.array(dtype=wptype),
        tr5: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        tr2[0] = wptype(2) * wp.trace(v2[0])
        tr3[0] = wptype(2) * wp.trace(v3[0])
        tr4[0] = wptype(2) * wp.trace(v4[0])
        tr5[0] = wptype(2) * wp.trace(v5[0])

    kernel = getkernel(check_mat_trace, suffix=dtype.__name__)

    if register_kernels:
        return

    v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    tr2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tr3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tr4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tr5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[v2, v3, v4, v5],
            outputs=[tr2, tr3, tr4, tr5],
            device=device,
        )

    assert_np_equal(tr2.numpy()[0], 2 * np.trace(v2.numpy()[0]), tol=10 * tol)
    assert_np_equal(tr3.numpy()[0], 2 * np.trace(v3.numpy()[0]), tol=10 * tol)
    assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
    # the 5x5 forward result must be checked too; this line previously
    # re-asserted the 4x4 case, leaving tr5 unverified
    assert_np_equal(tr5.numpy()[0], 2 * np.trace(v5.numpy()[0]), tol=200 * tol)

    # gradients are only checked for floating point dtypes
    if dtype in np_float_types:
        # d(2 * trace(v)) / dv is twice the identity matrix of matching size
        for dim, tr, v in [(2, tr2, v2), (3, tr3, v3), (4, tr4, v4), (5, tr5, v5)]:
            tape.backward(loss=tr)
            vgrads = tape.gradients[v].numpy()[0]
            assert_np_equal(vgrads, 2.0 * np.eye(dim), tol=10 * tol)
            tape.zero()
2634
+
2635
+
2636
def test_diag(test, device, dtype, register_kernels=False):
    """Exercise ``wp.diag`` on a length-5 vector.

    A kernel builds twice the diagonal matrix of a 5-vector and writes its 25
    components into a flat output array, which is compared against
    ``np.diag``. For float dtypes each output component is back-propagated on
    its own and the gradient with respect to the vector is checked.

    Args:
        test: Unused by this function.
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, return right after the kernel is built.
    """
    rng = np.random.default_rng(123)

    # dtypes that are not listed here are compared exactly (tolerance 0)
    tolerances = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_diag(
        s5: wp.array(dtype=vec5),
        outcomponents: wp.array(dtype=wptype),
    ):
        # the matrix is doubled so the backward pass has a non-trivial factor
        m55result = wptype(2) * wp.diag(s5[0])

        idx = 0
        for i in range(5):
            for j in range(5):
                outcomponents[idx] = m55result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_diag, suffix=dtype.__name__)

    if register_kernels:
        return

    s5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
    outcomponents = wp.zeros(5 * 5, dtype=wptype, requires_grad=True, device=device)
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)

    assert_np_equal(outcomponents.numpy(), 2 * np.diag(s5.numpy()[0]), tol=tol)

    # gradients are only checked for floating point dtypes
    if dtype not in np_float_types:
        return

    idx = 0
    for i in range(5):
        for j in range(5):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
            tape.backward(loss=out)

            # only diagonal outputs depend on the vector: component (i, i) on s5[i]
            expected_grad = np.zeros(5, dtype=dtype)
            if i == j:
                expected_grad[i] = 2
            assert_np_equal(tape.gradients[s5].numpy()[0], expected_grad, tol=10 * tol)
            tape.zero()

            idx += 1
2692
+
2693
+
2694
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Check that separately created matrix types with equal shape and dtype interoperate.

    Two sets of 2x2 to 5x5 matrix types are created with identical arguments.
    The kernel is declared with the first set, compares its arguments against
    values built from both sets, and is launched with values built from the
    second set.

    Args:
        test: Unused by this function.
        device: Device on which the kernel is launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, return right after the kernel is built.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # types used to annotate the kernel parameters
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    # second set, created with exactly the same shape and dtype
    mat22_equiv = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33_equiv = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44_equiv = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55_equiv = wp.types.matrix(shape=(5, 5), dtype=wptype)

    def check_equivalence(
        m2: mat22,
        m3: mat33,
        m4: mat44,
        m5: mat55,
    ):
        # compare against values constructed from the declared types ...
        wp.expect_eq(m2, mat22(wptype(42)))
        wp.expect_eq(m3, mat33(wptype(43)))
        wp.expect_eq(m4, mat44(wptype(44)))
        wp.expect_eq(m5, mat55(wptype(45)))

        # ... and against values constructed from the second set of types
        wp.expect_eq(m2, mat22_equiv(wptype(42)))
        wp.expect_eq(m3, mat33_equiv(wptype(43)))
        wp.expect_eq(m4, mat44_equiv(wptype(44)))
        wp.expect_eq(m5, mat55_equiv(wptype(45)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # the launch arguments are instances of the second set of types
    launch_args = [
        mat22_equiv(42),
        mat33_equiv(43),
        mat44_equiv(44),
        mat55_equiv(45),
    ]

    wp.launch(kernel, dim=1, inputs=launch_args, device=device)
2738
+
2739
+
2740
def test_conversions(test, device, dtype, register_kernels=False):
    """Check that ``wp.mat22`` values can be supplied from several container kinds.

    A reference ``wp.mat22(1, 2, 3, 4)`` is compared inside a kernel against
    six other arguments. The kernel is launched twice: first with matrices
    explicitly constructed from nested/flat tuples, lists and NumPy arrays,
    then with the raw containers passed straight to ``wp.launch``.

    Args:
        test: Unused by this function.
        device: Device on which the kernel is launched.
        dtype: NumPy scalar type; used for the kernel name suffix and as the
            dtype of the NumPy array inputs.
        register_kernels: When true, return right after the kernel is built.
    """

    def check_matrices_equal(
        m0: wp.mat22,
        m1: wp.mat22,
        m2: wp.mat22,
        m3: wp.mat22,
        m4: wp.mat22,
        m5: wp.mat22,
        m6: wp.mat22,
    ):
        # every argument must match the reference matrix m0
        wp.expect_eq(m1, m0)
        wp.expect_eq(m2, m0)
        wp.expect_eq(m3, m0)
        wp.expect_eq(m4, m0)
        wp.expect_eq(m5, m0)
        wp.expect_eq(m6, m0)

    kernel = getkernel(check_matrices_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.mat22(1, 2, 3, 4)

    # explicit conversions: build wp.mat22 values from different containers
    explicit = [
        wp.mat22(((1, 2), (3, 4))),  # nested tuples
        wp.mat22([[1, 2], [3, 4]]),  # nested lists
        wp.mat22(np.array([[1, 2], [3, 4]], dtype=dtype)),  # 2d array
        wp.mat22((1, 2, 3, 4)),  # flat tuple
        wp.mat22([1, 2, 3, 4]),  # flat list
        wp.mat22(np.array([1, 2, 3, 4], dtype=dtype)),  # 1d array
    ]
    wp.launch(kernel, dim=1, inputs=[reference] + explicit, device=device)

    # implicit conversions: hand the raw containers to wp.launch() as matrices
    implicit = [
        ((1, 2), (3, 4)),  # nested tuples
        [[1, 2], [3, 4]],  # nested lists
        np.array([[1, 2], [3, 4]], dtype=dtype),  # 2d array
        (1, 2, 3, 4),  # flat tuple
        [1, 2, 3, 4],  # flat list
        np.array([1, 2, 3, 4], dtype=dtype),  # 1d array
    ]
    wp.launch(kernel, dim=1, inputs=[reference] + implicit, device=device)
2783
+
2784
+
2785
# devices on which the per-device tests below are registered
devices = get_test_devices()
2786
+
2787
+
2788
class TestMatScalarOps(unittest.TestCase):
    """Empty test case; its test methods are attached by the registration loop below."""
2790
+
2791
+
2792
# (test name prefix, test function) pairs that are registered through
# add_function_test_register_kernel for every scalar dtype, in this order
_KERNEL_TESTS = (
    ("test_constructors", test_constructors),
    ("test_anon_type_instance", test_anon_type_instance),
    ("test_identity", test_identity),
    ("test_indexing", test_indexing),
    ("test_equality", test_equality),
    ("test_scalar_multiplication", test_scalar_multiplication),
    ("test_matvec_multiplication", test_matvec_multiplication),
    ("test_vecmat_multiplication", test_vecmat_multiplication),
    ("test_matmat_multiplication", test_matmat_multiplication),
    ("test_cw_multiplication", test_cw_multiplication),
    ("test_cw_division", test_cw_division),
    ("test_outer_product", test_outer_product),
    ("test_transpose", test_transpose),
    ("test_scalar_division", test_scalar_division),
    ("test_addition", test_addition),
    ("test_ddot", test_ddot),
    ("test_trace", test_trace),
    ("test_diag", test_diag),
    # NOTE(review): "test_get_diag" is bound to test_diag, so it re-runs the
    # wp.diag test under a second name - confirm whether a dedicated
    # test_get_diag function was intended here.
    ("test_get_diag", test_diag),
    ("test_equivalent_types", test_equivalent_types),
    ("test_conversions", test_conversions),
    ("test_constants", test_constants),
)

for dtype in np_scalar_types:
    _suffix = dtype.__name__

    # these two are registered with the plain helper; test_components is given devices=None
    add_function_test(TestMatScalarOps, f"test_arrays_{_suffix}", test_arrays, devices=devices, dtype=dtype)
    add_function_test(TestMatScalarOps, f"test_components_{_suffix}", test_components, devices=None, dtype=dtype)

    for _prefix, _func in _KERNEL_TESTS:
        add_function_test_register_kernel(
            TestMatScalarOps, f"{_prefix}_{_suffix}", _func, devices=devices, dtype=dtype
        )
2885
+
2886
+
2887
if __name__ == "__main__":
    # clear Warp's kernel cache before the run
    wp.build.clear_kernel_cache()
    # verbose output; stop at the first failing test
    unittest.main(failfast=True, verbosity=2)