warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_mat.py CHANGED
@@ -5,9 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
11
+
9
12
  import warp as wp
10
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
11
14
 
12
15
  wp.init()
13
16
 
@@ -19,37 +22,24 @@ np_signed_int_types = [
19
22
  np.byte,
20
23
  ]
21
24
 
22
- np_unsigned_int_types = [
23
- np.uint8,
24
- np.uint16,
25
- np.uint32,
26
- np.uint64,
27
- np.ubyte,
28
- ]
29
-
30
- np_int_types = np_signed_int_types + np_unsigned_int_types
31
-
32
25
  np_float_types = [np.float16, np.float32, np.float64]
33
26
 
34
- np_scalar_types = np_int_types + np_float_types
35
-
36
27
 
37
- def randvals(shape, dtype):
28
+ def randvals(rng, shape, dtype):
38
29
  if dtype in np_float_types:
39
- return np.random.randn(*shape).astype(dtype)
30
+ return rng.standard_normal(size=shape).astype(dtype)
40
31
  elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
41
- return np.random.randint(1, 3, size=shape, dtype=dtype)
42
- return np.random.randint(1, 5, size=shape, dtype=dtype)
32
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
33
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
43
34
 
44
35
 
45
36
  kernel_cache = dict()
46
37
 
47
38
 
48
39
  def getkernel(func, suffix=""):
49
- module = wp.get_module(func.__module__)
50
40
  key = func.__name__ + "_" + suffix
51
41
  if key not in kernel_cache:
52
- kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
42
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
53
43
  return kernel_cache[key]
54
44
 
55
45
 
@@ -63,324 +53,224 @@ def get_select_kernel(dtype):
63
53
 
64
54
  return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
65
55
 
56
+ wp.launch(kernel, dim=1, inputs=[])
66
57
 
67
- def test_arrays(test, device, dtype):
68
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
69
-
70
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
71
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
72
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
73
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
74
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
75
58
 
76
- np.random.seed(123)
59
+ def test_anon_constructor_error_shape_keyword_missing(test, device):
60
+ @wp.kernel
61
+ def kernel():
62
+ wp.matrix(1.0, 2.0, 3.0)
77
63
 
78
- v2_np = randvals([10, 2, 2], dtype)
79
- v3_np = randvals([10, 3, 3], dtype)
80
- v4_np = randvals([10, 4, 4], dtype)
81
- v5_np = randvals([10, 5, 5], dtype)
82
- v32_np = randvals([10, 3, 2], dtype)
64
+ with test.assertRaisesRegex(
65
+ RuntimeError,
66
+ r"shape keyword must be specified when calling matrix\(\) function$",
67
+ ):
68
+ wp.launch(
69
+ kernel,
70
+ dim=1,
71
+ inputs=[],
72
+ device=device,
73
+ )
83
74
 
84
- v2 = wp.array(v2_np, dtype=mat22, requires_grad=True, device=device)
85
- v3 = wp.array(v3_np, dtype=mat33, requires_grad=True, device=device)
86
- v4 = wp.array(v4_np, dtype=mat44, requires_grad=True, device=device)
87
- v5 = wp.array(v5_np, dtype=mat55, requires_grad=True, device=device)
88
- v32 = wp.array(v32_np, dtype=mat32, requires_grad=True, device=device)
89
75
 
90
- assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
91
- assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
92
- assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
93
- assert_np_equal(v5.numpy(), v5_np, tol=1.0e-6)
94
- assert_np_equal(v32.numpy(), v32_np, tol=1.0e-6)
76
+ def test_anon_constructor_error_dtype_keyword_missing(test, device):
77
+ @wp.kernel
78
+ def kernel():
79
+ wp.matrix(shape=(3, 3))
95
80
 
96
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
97
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
98
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
81
+ with test.assertRaisesRegex(
82
+ RuntimeError,
83
+ r"matrix\(\) must have dtype as a keyword argument if it has no " r"positional arguments$",
84
+ ):
85
+ wp.launch(
86
+ kernel,
87
+ dim=1,
88
+ inputs=[],
89
+ device=device,
90
+ )
99
91
 
100
- v2 = wp.array(v2_np, dtype=mat22, requires_grad=True, device=device)
101
- v3 = wp.array(v3_np, dtype=mat33, requires_grad=True, device=device)
102
- v4 = wp.array(v4_np, dtype=mat44, requires_grad=True, device=device)
103
92
 
104
- assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
105
- assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
106
- assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
93
+ def test_anon_constructor_error_shape_mismatch(test, device):
94
+ @wp.kernel
95
+ def kernel():
96
+ wp.matrix(
97
+ wp.matrix(shape=(1, 2), dtype=float),
98
+ shape=(3, 4),
99
+ dtype=float,
100
+ )
107
101
 
102
+ with test.assertRaisesRegex(
103
+ RuntimeError,
104
+ r"Incompatible matrix sizes for casting copy constructor, " r"\(3, 4\) vs \(1, 2\)$",
105
+ ):
106
+ wp.launch(
107
+ kernel,
108
+ dim=1,
109
+ inputs=[],
110
+ device=device,
111
+ )
108
112
 
109
- def test_constants(test, device, dtype, register_kernels=False):
110
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
111
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
112
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
113
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
114
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
115
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
116
113
 
117
- cm22 = wp.constant(mat22(22))
118
- cm33 = wp.constant(mat33(33))
119
- cm44 = wp.constant(mat44(44))
120
- cm55 = wp.constant(mat55(55))
121
- cm32 = wp.constant(mat32(32))
114
+ def test_anon_constructor_error_invalid_arg_count(test, device):
115
+ @wp.kernel
116
+ def kernel():
117
+ wp.matrix(1.0, 2.0, 3.0, shape=(2, 2), dtype=float)
122
118
 
123
- def check_matrix_constants():
124
- wp.expect_eq(cm22, mat22(wptype(22)))
125
- wp.expect_eq(cm33, mat33(wptype(33)))
126
- wp.expect_eq(cm44, mat44(wptype(44)))
127
- wp.expect_eq(cm55, mat55(wptype(55)))
128
- wp.expect_eq(cm32, mat32(wptype(32)))
119
+ with test.assertRaisesRegex(
120
+ RuntimeError,
121
+ r"Wrong number of arguments for matrix\(\) function, must initialize "
122
+ r"with either a scalar value, or m\*n values$",
123
+ ):
124
+ wp.launch(
125
+ kernel,
126
+ dim=1,
127
+ inputs=[],
128
+ device=device,
129
+ )
129
130
 
130
- kernel = getkernel(check_matrix_constants, suffix=dtype.__name__)
131
131
 
132
- if register_kernels:
133
- return
132
+ def test_tpl_constructor_error_incompatible_sizes(test, device):
133
+ @wp.kernel
134
+ def kernel():
135
+ wp.mat33(wp.mat22(1.0, 2.0, 3.0, 4.0))
134
136
 
135
- wp.launch(kernel, dim=1, inputs=[])
137
+ with test.assertRaisesRegex(
138
+ RuntimeError,
139
+ r"Incompatible matrix sizes for casting copy constructor, " r"\(3, 3\) vs \(2, 2\)$",
140
+ ):
141
+ wp.launch(
142
+ kernel,
143
+ dim=1,
144
+ inputs=[],
145
+ device=device,
146
+ )
136
147
 
137
148
 
138
- def test_constructors(test, device, dtype, register_kernels=False):
139
- np.random.seed(123)
149
+ def test_tpl_constructor_error_invalid_scalar_type(test, device):
150
+ @wp.kernel
151
+ def kernel():
152
+ wp.mat22(1, 2, 3, 4)
140
153
 
141
- tol = {
142
- np.float16: 1.0e-3,
143
- np.float32: 1.0e-6,
144
- np.float64: 1.0e-8,
145
- }.get(dtype, 0)
154
+ with test.assertRaisesRegex(
155
+ RuntimeError,
156
+ r"Wrong scalar type for mat 2,2,<class 'warp.types.float32'> constructor$",
157
+ ):
158
+ wp.launch(
159
+ kernel,
160
+ dim=1,
161
+ inputs=[],
162
+ device=device,
163
+ )
146
164
 
147
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
148
- vec2 = wp.types.vector(length=2, dtype=wptype)
149
- vec3 = wp.types.vector(length=3, dtype=wptype)
150
- vec4 = wp.types.vector(length=4, dtype=wptype)
151
- vec5 = wp.types.vector(length=5, dtype=wptype)
152
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
153
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
154
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
155
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
156
165
 
157
- output_select_kernel = get_select_kernel(wptype)
166
+ def test_tpl_constructor_error_invalid_vector_count(test, device):
167
+ @wp.kernel
168
+ def kernel():
169
+ wp.mat22(wp.vec3(1.0, 2.0, 3.0))
158
170
 
159
- def check_scalar_mat_constructor(
160
- input: wp.array(dtype=wptype),
161
- outcomponents: wp.array(dtype=wptype),
171
+ with test.assertRaisesRegex(
172
+ RuntimeError,
173
+ r"Wrong number of vectors when attempting to construct a matrix " r"with column vectors$",
162
174
  ):
163
- # multiply outputs by 2 so we've got something to backpropagate:
164
- m2result = wptype(2) * mat22(input[0])
165
- m3result = wptype(2) * mat33(input[0])
166
- m4result = wptype(2) * mat44(input[0])
167
- m5result = wptype(2) * mat55(input[0])
168
-
169
- idx = 0
170
- for i in range(2):
171
- for j in range(2):
172
- outcomponents[idx] = m2result[i, j]
173
- idx = idx + 1
174
-
175
- for i in range(3):
176
- for j in range(3):
177
- outcomponents[idx] = m3result[i, j]
178
- idx = idx + 1
175
+ wp.launch(
176
+ kernel,
177
+ dim=1,
178
+ inputs=[],
179
+ device=device,
180
+ )
179
181
 
180
- for i in range(4):
181
- for j in range(4):
182
- outcomponents[idx] = m4result[i, j]
183
- idx = idx + 1
184
182
 
185
- for i in range(5):
186
- for j in range(5):
187
- outcomponents[idx] = m5result[i, j]
188
- idx = idx + 1
183
+ def test_tpl_constructor_error_invalid_vector_shape(test, device):
184
+ @wp.kernel
185
+ def kernel():
186
+ wp.mat22(wp.vec3(1.0, 2.0, 3.0), wp.vec3(4.0, 5.0, 6.0))
189
187
 
190
- def check_component_mat_constructor(
191
- input: wp.array(dtype=wptype),
192
- outcomponents: wp.array(dtype=wptype),
188
+ with test.assertRaisesRegex(
189
+ RuntimeError,
190
+ r"Wrong vector row count when attempting to construct a matrix " r"with column vectors$",
193
191
  ):
194
- # multiply outputs by 2 so we've got something to backpropagate:
195
- m2result = wptype(2) * mat22(input[0], input[1], input[2], input[3])
196
- m3result = wptype(2) * mat33(
197
- input[4],
198
- input[5],
199
- input[6],
200
- input[7],
201
- input[8],
202
- input[9],
203
- input[10],
204
- input[11],
205
- input[12],
206
- )
207
- m4result = wptype(2) * mat44(
208
- input[13],
209
- input[14],
210
- input[15],
211
- input[16],
212
- input[17],
213
- input[18],
214
- input[19],
215
- input[20],
216
- input[21],
217
- input[22],
218
- input[23],
219
- input[24],
220
- input[25],
221
- input[26],
222
- input[27],
223
- input[28],
224
- )
225
- m5result = wptype(2) * mat55(
226
- input[29],
227
- input[30],
228
- input[31],
229
- input[32],
230
- input[33],
231
- input[34],
232
- input[35],
233
- input[36],
234
- input[37],
235
- input[38],
236
- input[39],
237
- input[40],
238
- input[41],
239
- input[42],
240
- input[43],
241
- input[44],
242
- input[45],
243
- input[46],
244
- input[47],
245
- input[48],
246
- input[49],
247
- input[50],
248
- input[51],
249
- input[52],
250
- input[53],
192
+ wp.launch(
193
+ kernel,
194
+ dim=1,
195
+ inputs=[],
196
+ device=device,
251
197
  )
252
198
 
253
- idx = 0
254
- for i in range(2):
255
- for j in range(2):
256
- outcomponents[idx] = m2result[i, j]
257
- idx = idx + 1
258
-
259
- for i in range(3):
260
- for j in range(3):
261
- outcomponents[idx] = m3result[i, j]
262
- idx = idx + 1
263
-
264
- for i in range(4):
265
- for j in range(4):
266
- outcomponents[idx] = m4result[i, j]
267
- idx = idx + 1
268
199
 
269
- for i in range(5):
270
- for j in range(5):
271
- outcomponents[idx] = m5result[i, j]
272
- idx = idx + 1
200
+ def test_tpl_constructor_error_invalid_arg_count(test, device):
201
+ @wp.kernel
202
+ def kernel():
203
+ wp.mat22(1.0, 2.0, 3.0)
273
204
 
274
- def check_vector_mat_constructor(
275
- input: wp.array(dtype=wptype),
276
- outcomponents: wp.array(dtype=wptype),
205
+ with test.assertRaisesRegex(
206
+ RuntimeError,
207
+ r"Wrong number of scalars when attempting to construct a matrix " r"from a list of components$",
277
208
  ):
278
- # multiply outputs by 2 so we've got something to backpropagate:
279
- m2result = wptype(2) * mat22(vec2(input[0], input[2]), vec2(input[1], input[3]))
280
- m3result = wptype(2) * mat33(
281
- vec3(input[4], input[7], input[10]),
282
- vec3(input[5], input[8], input[11]),
283
- vec3(input[6], input[9], input[12]),
284
- )
285
- m4result = wptype(2) * mat44(
286
- vec4(input[13], input[17], input[21], input[25]),
287
- vec4(input[14], input[18], input[22], input[26]),
288
- vec4(input[15], input[19], input[23], input[27]),
289
- vec4(input[16], input[20], input[24], input[28]),
290
- )
291
- m5result = wptype(2) * mat55(
292
- vec5(input[29], input[34], input[39], input[44], input[49]),
293
- vec5(input[30], input[35], input[40], input[45], input[50]),
294
- vec5(input[31], input[36], input[41], input[46], input[51]),
295
- vec5(input[32], input[37], input[42], input[47], input[52]),
296
- vec5(input[33], input[38], input[43], input[48], input[53]),
209
+ wp.launch(
210
+ kernel,
211
+ dim=1,
212
+ inputs=[],
213
+ device=device,
297
214
  )
298
215
 
299
- idx = 0
300
- for i in range(2):
301
- for j in range(2):
302
- outcomponents[idx] = m2result[i, j]
303
- idx = idx + 1
304
-
305
- for i in range(3):
306
- for j in range(3):
307
- outcomponents[idx] = m3result[i, j]
308
- idx = idx + 1
309
-
310
- for i in range(4):
311
- for j in range(4):
312
- outcomponents[idx] = m4result[i, j]
313
- idx = idx + 1
314
216
 
315
- for i in range(5):
316
- for j in range(5):
317
- outcomponents[idx] = m5result[i, j]
318
- idx = idx + 1
217
+ def test_tpl_ops_with_anon(test, device):
218
+ mat22f = wp.mat((2, 2), dtype=float)
319
219
 
320
- kernel = getkernel(check_scalar_mat_constructor, suffix=dtype.__name__)
321
- compkernel = getkernel(check_component_mat_constructor, suffix=dtype.__name__)
322
- veckernel = getkernel(check_vector_mat_constructor, suffix=dtype.__name__)
220
+ m = wp.mat22f(1.0, 2.0, 3.0, 4.0)
221
+ m += mat22f(2.0, 3.0, 4.0, 5.0)
222
+ m -= mat22f(3.0, 4.0, 5.0, 6.0)
223
+ test.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
323
224
 
324
- if register_kernels:
325
- return
225
+ m = mat22f(1.0, 2.0, 3.0, 4.0)
226
+ m += wp.mat22f(2.0, 3.0, 4.0, 5.0)
227
+ m -= wp.mat22f(3.0, 4.0, 5.0, 6.0)
228
+ test.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
326
229
 
327
- input = wp.array(randvals([1], dtype), requires_grad=True, device=device)
328
- val = input.numpy()[0]
329
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
330
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
331
230
 
332
- wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
231
+ def test_py_arithmetic_ops(test, device, dtype):
232
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
333
233
 
334
- assert_np_equal(outcomponents.numpy()[:4], 2 * val * np.ones(2 * 2), tol=tol)
335
- assert_np_equal(outcomponents.numpy()[4:13], 2 * val * np.ones(3 * 3), tol=tol)
336
- assert_np_equal(outcomponents.numpy()[13:29], 2 * val * np.ones(4 * 4), tol=tol)
337
- assert_np_equal(outcomponents.numpy()[29:54], 2 * val * np.ones(5 * 5), tol=tol)
234
+ def make_mat(*args):
235
+ if wptype in wp.types.int_types:
236
+ # Cast to the correct integer type to simulate wrapping.
237
+ return tuple(tuple(wptype._type_(x).value for x in row) for row in args)
338
238
 
339
- if dtype in np_float_types:
340
- for idx in range(len(outcomponents)):
341
- tape = wp.Tape()
342
- with tape:
343
- wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
344
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
345
- tape.backward(loss=out)
346
- test.assertEqual(tape.gradients[input].numpy()[0], 2)
347
- tape.zero()
239
+ return args
348
240
 
349
- input = wp.array(randvals([2 * 2 + 3 * 3 + 4 * 4 + 5 * 5], dtype), requires_grad=True, device=device)
241
+ def make_vec(*args):
242
+ if wptype in wp.types.int_types:
243
+ # Cast to the correct integer type to simulate wrapping.
244
+ return tuple(wptype._type_(x).value for x in args)
350
245
 
351
- wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
352
- assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
246
+ return args
353
247
 
354
- if dtype in np_float_types:
355
- for idx in range(len(outcomponents)):
356
- tape = wp.Tape()
357
- with tape:
358
- wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
359
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
360
- tape.backward(loss=out)
361
- expectedgrads = np.zeros(len(input))
362
- expectedgrads[idx] = 2
363
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
364
- tape.zero()
248
+ mat_cls = wp.mat((3, 3), wptype)
249
+ vec_cls = wp.vec(3, wptype)
365
250
 
366
- wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
367
- assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
251
+ m = mat_cls(((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
252
+ test.assertSequenceEqual(+m, make_mat((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
253
+ test.assertSequenceEqual(-m, make_mat((1, -2, -3), (-4, 5, -6), (-7, -8, 9)))
254
+ test.assertSequenceEqual(m + mat_cls((5, 5, 5) * 3), make_mat((4, 7, 8), (9, 0, 11), (12, 13, -4)))
255
+ test.assertSequenceEqual(m - mat_cls((5, 5, 5) * 3), make_mat((-6, -3, -2), (-1, -10, 1), (2, 3, -14)))
256
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(20, 25, 30))
257
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(20, 25, 30))
258
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(50, 25, 0))
259
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(50, 25, 0))
368
260
 
369
- if dtype in np_float_types:
370
- for idx in range(len(outcomponents)):
371
- tape = wp.Tape()
372
- with tape:
373
- wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
374
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
375
- tape.backward(loss=out)
376
- expectedgrads = np.zeros(len(input))
377
- expectedgrads[idx] = 2
378
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
379
- tape.zero()
261
+ m = mat_cls(((2, 4, 6), (8, 10, 12), (14, 16, 18)))
262
+ test.assertSequenceEqual(m * wptype(2), make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
263
+ test.assertSequenceEqual(wptype(2) * m, make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
264
+ test.assertSequenceEqual(m / wptype(2), make_mat((1, 2, 3), (4, 5, 6), (7, 8, 9)))
265
+ test.assertSequenceEqual(wptype(5040) / m, make_mat((2520, 1260, 840), (630, 504, 420), (360, 315, 280)))
266
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(60, 150, 240))
267
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(60, 150, 240))
268
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(120, 150, 180))
269
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(120, 150, 180))
380
270
 
381
271
 
382
272
  def test_quat_constructor(test, device, dtype, register_kernels=False):
383
- np.random.seed(123)
273
+ rng = np.random.default_rng(123)
384
274
 
385
275
  tol = {
386
276
  np.float16: 1.0e-3,
@@ -429,15 +319,15 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
429
319
  return
430
320
 
431
321
  # translation:
432
- p = wp.array(np.random.randn(1, 3).astype(dtype), dtype=vec3, requires_grad=True, device=device)
322
+ p = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
433
323
 
434
324
  # generate a normalized quaternion for the rotation:
435
- r = np.random.randn(1, 4)
325
+ r = rng.standard_normal(size=(1, 4))
436
326
  r /= np.linalg.norm(r)
437
327
  r = wp.array(r.astype(dtype), dtype=quat, requires_grad=True, device=device)
438
328
 
439
329
  # scale:
440
- s = wp.array(np.random.randn(1, 3).astype(dtype), dtype=vec3, requires_grad=True, device=device)
330
+ s = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
441
331
 
442
332
  # just going to generate the matrix using the constructor, then
443
333
  # more manually, and make sure the values/gradients are the same:
@@ -478,95 +368,11 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
478
368
  idx = idx + 1
479
369
 
480
370
 
481
- def test_indexing(test, device, dtype, register_kernels=False):
482
- np.random.seed(123)
483
-
484
- tol = {
485
- np.float16: 1.0e-3,
486
- np.float32: 1.0e-6,
487
- np.float64: 1.0e-8,
488
- }.get(dtype, 0)
489
-
490
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
491
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
492
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
493
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
494
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
495
-
496
- output_select_kernel = get_select_kernel(wptype)
497
-
498
- def check_mat_indexing(
499
- m2: wp.array(dtype=mat22),
500
- m3: wp.array(dtype=mat33),
501
- m4: wp.array(dtype=mat44),
502
- m5: wp.array(dtype=mat55),
503
- outcomponents: wp.array(dtype=wptype),
504
- ):
505
- # multiply outputs by 2 so we've got something to backpropagate:
506
- idx = 0
507
- for i in range(2):
508
- for j in range(2):
509
- outcomponents[idx] = wptype(2) * m2[0][i, j]
510
- idx = idx + 1
511
-
512
- for i in range(3):
513
- for j in range(3):
514
- outcomponents[idx] = wptype(2) * m3[0][i, j]
515
- idx = idx + 1
516
-
517
- for i in range(4):
518
- for j in range(4):
519
- outcomponents[idx] = wptype(2) * m4[0][i, j]
520
- idx = idx + 1
521
-
522
- for i in range(5):
523
- for j in range(5):
524
- outcomponents[idx] = wptype(2) * m5[0][i, j]
525
- idx = idx + 1
526
-
527
- kernel = getkernel(check_mat_indexing, suffix=dtype.__name__)
528
-
529
- if register_kernels:
530
- return
531
-
532
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
533
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
534
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
535
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
536
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
537
-
538
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
539
-
540
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy().reshape(-1), tol=tol)
541
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy().reshape(-1), tol=tol)
542
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy().reshape(-1), tol=tol)
543
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy().reshape(-1), tol=tol)
544
-
545
- if dtype in np_float_types:
546
- idx = 0
547
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
548
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
549
- for i in range(dim):
550
- for j in range(dim):
551
- tape = wp.Tape()
552
- with tape:
553
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
554
- wp.launch(
555
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
556
- )
557
- tape.backward(loss=out)
558
- expectedresult = np.zeros((dim, dim), dtype=dtype)
559
- expectedresult[i, j] = 2
560
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
561
- tape.zero()
562
- idx = idx + 1
563
-
564
-
565
- def test_equality(test, device, dtype, register_kernels=False):
566
- np.random.seed(123)
371
+ def test_negation(test, device, dtype, register_kernels=False):
372
+ rng = np.random.default_rng(123)
567
373
 
568
374
  tol = {
569
- np.float16: 1.0e-3,
375
+ np.float16: 1.0e-2,
570
376
  np.float32: 1.0e-6,
571
377
  np.float64: 1.0e-8,
572
378
  }.get(dtype, 0)
@@ -574,1614 +380,85 @@ def test_equality(test, device, dtype, register_kernels=False):
574
380
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
575
381
  mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
576
382
  mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
577
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
578
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
579
-
580
- def check_mat_equality():
581
- wp.expect_eq(
582
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
583
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
584
- )
585
- wp.expect_neq(
586
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), -wptype(4.0)),
587
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
588
- )
589
-
590
- wp.expect_eq(
591
- mat33(
592
- wptype(1.0),
593
- wptype(2.0),
594
- wptype(3.0),
595
- wptype(4.0),
596
- wptype(5.0),
597
- wptype(6.0),
598
- wptype(7.0),
599
- wptype(8.0),
600
- wptype(9.0),
601
- ),
602
- mat33(
603
- wptype(1.0),
604
- wptype(2.0),
605
- wptype(3.0),
606
- wptype(4.0),
607
- wptype(5.0),
608
- wptype(6.0),
609
- wptype(7.0),
610
- wptype(8.0),
611
- wptype(9.0),
612
- ),
613
- )
614
- wp.expect_neq(
615
- mat33(
616
- wptype(1.0),
617
- wptype(2.0),
618
- wptype(3.0),
619
- wptype(4.0),
620
- wptype(5.0),
621
- wptype(6.0),
622
- wptype(7.0),
623
- wptype(8.0),
624
- wptype(9.0),
625
- ),
626
- mat33(
627
- wptype(1.0),
628
- wptype(2.0),
629
- wptype(3.0),
630
- -wptype(4.0),
631
- wptype(5.0),
632
- wptype(6.0),
633
- wptype(7.0),
634
- wptype(8.0),
635
- wptype(9.0),
636
- ),
637
- )
638
-
639
- wp.expect_eq(
640
- mat44(
641
- wptype(1.0),
642
- wptype(2.0),
643
- wptype(3.0),
644
- wptype(4.0),
645
- wptype(5.0),
646
- wptype(6.0),
647
- wptype(7.0),
648
- wptype(8.0),
649
- wptype(9.0),
650
- wptype(10.0),
651
- wptype(11.0),
652
- wptype(12.0),
653
- wptype(13.0),
654
- wptype(14.0),
655
- wptype(15.0),
656
- wptype(16.0),
657
- ),
658
- mat44(
659
- wptype(1.0),
660
- wptype(2.0),
661
- wptype(3.0),
662
- wptype(4.0),
663
- wptype(5.0),
664
- wptype(6.0),
665
- wptype(7.0),
666
- wptype(8.0),
667
- wptype(9.0),
668
- wptype(10.0),
669
- wptype(11.0),
670
- wptype(12.0),
671
- wptype(13.0),
672
- wptype(14.0),
673
- wptype(15.0),
674
- wptype(16.0),
675
- ),
676
- )
677
-
678
- wp.expect_neq(
679
- mat44(
680
- wptype(1.0),
681
- wptype(2.0),
682
- wptype(3.0),
683
- wptype(4.0),
684
- wptype(5.0),
685
- wptype(6.0),
686
- wptype(7.0),
687
- wptype(8.0),
688
- wptype(9.0),
689
- wptype(10.0),
690
- wptype(11.0),
691
- wptype(12.0),
692
- wptype(13.0),
693
- wptype(14.0),
694
- wptype(15.0),
695
- wptype(16.0),
696
- ),
697
- mat44(
698
- -wptype(1.0),
699
- wptype(2.0),
700
- wptype(3.0),
701
- wptype(4.0),
702
- wptype(5.0),
703
- wptype(6.0),
704
- wptype(7.0),
705
- wptype(8.0),
706
- wptype(9.0),
707
- wptype(10.0),
708
- wptype(11.0),
709
- wptype(12.0),
710
- wptype(13.0),
711
- wptype(14.0),
712
- wptype(15.0),
713
- wptype(16.0),
714
- ),
715
- )
716
-
717
- wp.expect_eq(
718
- mat55(
719
- wptype(1.0),
720
- wptype(2.0),
721
- wptype(3.0),
722
- wptype(4.0),
723
- wptype(5.0),
724
- wptype(6.0),
725
- wptype(7.0),
726
- wptype(8.0),
727
- wptype(9.0),
728
- wptype(10.0),
729
- wptype(11.0),
730
- wptype(12.0),
731
- wptype(13.0),
732
- wptype(14.0),
733
- wptype(15.0),
734
- wptype(16.0),
735
- wptype(17.0),
736
- wptype(18.0),
737
- wptype(19.0),
738
- wptype(20.0),
739
- wptype(21.0),
740
- wptype(22.0),
741
- wptype(23.0),
742
- wptype(24.0),
743
- wptype(25.0),
744
- ),
745
- mat55(
746
- wptype(1.0),
747
- wptype(2.0),
748
- wptype(3.0),
749
- wptype(4.0),
750
- wptype(5.0),
751
- wptype(6.0),
752
- wptype(7.0),
753
- wptype(8.0),
754
- wptype(9.0),
755
- wptype(10.0),
756
- wptype(11.0),
757
- wptype(12.0),
758
- wptype(13.0),
759
- wptype(14.0),
760
- wptype(15.0),
761
- wptype(16.0),
762
- wptype(17.0),
763
- wptype(18.0),
764
- wptype(19.0),
765
- wptype(20.0),
766
- wptype(21.0),
767
- wptype(22.0),
768
- wptype(23.0),
769
- wptype(24.0),
770
- wptype(25.0),
771
- ),
772
- )
773
-
774
- wp.expect_neq(
775
- mat55(
776
- wptype(1.0),
777
- wptype(2.0),
778
- wptype(3.0),
779
- wptype(4.0),
780
- wptype(5.0),
781
- wptype(6.0),
782
- wptype(7.0),
783
- wptype(8.0),
784
- wptype(9.0),
785
- wptype(10.0),
786
- wptype(11.0),
787
- wptype(12.0),
788
- wptype(13.0),
789
- wptype(14.0),
790
- wptype(15.0),
791
- wptype(16.0),
792
- wptype(17.0),
793
- wptype(18.0),
794
- wptype(19.0),
795
- wptype(20.0),
796
- wptype(21.0),
797
- wptype(22.0),
798
- wptype(23.0),
799
- wptype(24.0),
800
- wptype(25.0),
801
- ),
802
- mat55(
803
- wptype(1.0),
804
- wptype(2.0),
805
- wptype(3.0),
806
- wptype(4.0),
807
- wptype(5.0),
808
- wptype(6.0),
809
- wptype(7.0),
810
- wptype(8.0),
811
- wptype(9.0),
812
- wptype(10.0),
813
- wptype(11.0),
814
- wptype(12.0),
815
- wptype(13.0),
816
- wptype(14.0),
817
- wptype(15.0),
818
- wptype(16.0),
819
- -wptype(17.0),
820
- wptype(18.0),
821
- wptype(19.0),
822
- wptype(20.0),
823
- wptype(21.0),
824
- wptype(22.0),
825
- wptype(23.0),
826
- wptype(24.0),
827
- wptype(25.0),
828
- ),
829
- )
830
-
831
- kernel = getkernel(check_mat_equality, suffix=dtype.__name__)
832
-
833
- if register_kernels:
834
- return
835
-
836
- wp.launch(kernel, dim=1, inputs=[], outputs=[], device=device)
837
-
838
-
839
- def test_negation(test, device, dtype, register_kernels=False):
840
- np.random.seed(123)
841
-
842
- tol = {
843
- np.float16: 1.0e-2,
844
- np.float32: 1.0e-6,
845
- np.float64: 1.0e-8,
846
- }.get(dtype, 0)
847
-
848
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
849
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
850
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
851
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
852
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
853
-
854
- output_select_kernel = get_select_kernel(wptype)
855
-
856
- def check_mat_negation(
857
- m2: wp.array(dtype=mat22),
858
- m3: wp.array(dtype=mat33),
859
- m4: wp.array(dtype=mat44),
860
- m5: wp.array(dtype=mat55),
861
- outcomponents: wp.array(dtype=wptype),
862
- ):
863
- mat2 = -m2[0]
864
- mat3 = -m3[0]
865
- mat4 = -m4[0]
866
- mat5 = -m5[0]
867
-
868
- # multiply outputs by 2 so we've got something to backpropagate:
869
- idx = 0
870
- for i in range(2):
871
- for j in range(2):
872
- outcomponents[idx] = wptype(2) * mat2[i, j]
873
- idx = idx + 1
874
-
875
- for i in range(3):
876
- for j in range(3):
877
- outcomponents[idx] = wptype(2) * mat3[i, j]
878
- idx = idx + 1
879
-
880
- for i in range(4):
881
- for j in range(4):
882
- outcomponents[idx] = wptype(2) * mat4[i, j]
883
- idx = idx + 1
884
-
885
- for i in range(5):
886
- for j in range(5):
887
- outcomponents[idx] = wptype(2) * mat5[i, j]
888
- idx = idx + 1
889
-
890
- kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
891
-
892
- if register_kernels:
893
- return
894
-
895
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
896
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
897
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
898
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
899
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
900
-
901
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
902
-
903
- assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
904
- assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
905
- assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
906
- assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
907
-
908
- if dtype in np_float_types:
909
- idx = 0
910
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
911
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
912
- for i in range(dim):
913
- for j in range(dim):
914
- tape = wp.Tape()
915
- with tape:
916
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
917
- wp.launch(
918
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
919
- )
920
- tape.backward(loss=out)
921
- expectedresult = np.zeros((dim, dim), dtype=dtype)
922
- expectedresult[i, j] = -2
923
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
924
- tape.zero()
925
- idx = idx + 1
926
-
927
-
928
- def test_transpose(test, device, dtype, register_kernels=False):
929
- np.random.seed(123)
930
-
931
- tol = {
932
- np.float16: 1.0e-2,
933
- np.float32: 1.0e-6,
934
- np.float64: 1.0e-8,
935
- }.get(dtype, 0)
936
-
937
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
938
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
939
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
940
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
941
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
942
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
943
-
944
- output_select_kernel = get_select_kernel(wptype)
945
-
946
- def check_mat_transpose(
947
- m2: wp.array(dtype=mat22),
948
- m3: wp.array(dtype=mat33),
949
- m4: wp.array(dtype=mat44),
950
- m5: wp.array(dtype=mat55),
951
- m32: wp.array(dtype=mat32),
952
- outcomponents: wp.array(dtype=wptype),
953
- ):
954
- # multiply outputs by 2 so we've got something to backpropagate:
955
- mat2 = wptype(2) * wp.transpose(m2[0])
956
- mat3 = wptype(2) * wp.transpose(m3[0])
957
- mat4 = wptype(2) * wp.transpose(m4[0])
958
- mat5 = wptype(2) * wp.transpose(m5[0])
959
- mat32 = wptype(2) * wp.transpose(m32[0])
960
-
961
- idx = 0
962
- for i in range(2):
963
- for j in range(2):
964
- outcomponents[idx] = mat2[i, j]
965
- idx = idx + 1
966
-
967
- for i in range(3):
968
- for j in range(3):
969
- outcomponents[idx] = mat3[i, j]
970
- idx = idx + 1
971
-
972
- for i in range(4):
973
- for j in range(4):
974
- outcomponents[idx] = mat4[i, j]
975
- idx = idx + 1
976
-
977
- for i in range(5):
978
- for j in range(5):
979
- outcomponents[idx] = mat5[i, j]
980
- idx = idx + 1
981
-
982
- for i in range(2):
983
- for j in range(3):
984
- outcomponents[idx] = mat32[i, j]
985
- idx = idx + 1
986
-
987
- kernel = getkernel(check_mat_transpose, suffix=dtype.__name__)
988
-
989
- if register_kernels:
990
- return
991
-
992
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
993
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
994
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
995
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
996
- m32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
997
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 3, dtype=wptype, requires_grad=True, device=device)
998
-
999
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1000
-
1001
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy()[0].T.reshape(-1), tol=tol)
1002
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy()[0].T.reshape(-1), tol=tol)
1003
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy()[0].T.reshape(-1), tol=tol)
1004
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy()[0].T.reshape(-1), tol=tol)
1005
- assert_np_equal(outcomponents.numpy()[54:], 2 * m32.numpy()[0].T.reshape(-1), tol=tol)
1006
-
1007
- if dtype in np_float_types:
1008
- idx = 0
1009
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1010
- for input in [m2, m3, m4, m5]:
1011
- for i in range(input.dtype._shape_[0]):
1012
- for j in range(input.dtype._shape_[1]):
1013
- tape = wp.Tape()
1014
- with tape:
1015
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1016
- wp.launch(
1017
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1018
- )
1019
- tape.backward(loss=out)
1020
- expectedresult = np.zeros((input.dtype._shape_[1], input.dtype._shape_[0]), dtype=dtype)
1021
- expectedresult[j, i] = 2
1022
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
1023
- tape.zero()
1024
- idx = idx + 1
1025
-
1026
-
1027
- def test_scalar_multiplication(test, device, dtype, register_kernels=False):
1028
- np.random.seed(123)
1029
-
1030
- tol = {
1031
- np.float16: 1.0e-2,
1032
- np.float32: 1.0e-6,
1033
- np.float64: 1.0e-8,
1034
- }.get(dtype, 0)
1035
-
1036
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1037
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1038
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1039
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1040
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1041
-
1042
- output_select_kernel = get_select_kernel(wptype)
1043
-
1044
- def check_mat_scalar_mul(
1045
- s: wp.array(dtype=wptype),
1046
- m2: wp.array(dtype=mat22),
1047
- m3: wp.array(dtype=mat33),
1048
- m4: wp.array(dtype=mat44),
1049
- m5: wp.array(dtype=mat55),
1050
- outcomponents: wp.array(dtype=wptype),
1051
- outcomponents_rightmul: wp.array(dtype=wptype),
1052
- ):
1053
- m2result = s[0] * m2[0]
1054
- m3result = s[0] * m3[0]
1055
- m4result = s[0] * m4[0]
1056
- m5result = s[0] * m5[0]
1057
-
1058
- m2resultright = m2[0] * s[0]
1059
- m3resultright = m3[0] * s[0]
1060
- m4resultright = m4[0] * s[0]
1061
- m5resultright = m5[0] * s[0]
1062
-
1063
- m2result_2 = s[0] * m2[0]
1064
- m3result_2 = s[0] * m3[0]
1065
- m4result_2 = s[0] * m4[0]
1066
- m5result_2 = s[0] * m5[0]
1067
-
1068
- m2resultright_2 = m2[0] * s[0]
1069
- m3resultright_2 = m3[0] * s[0]
1070
- m4resultright_2 = m4[0] * s[0]
1071
- m5resultright_2 = m5[0] * s[0]
1072
-
1073
- # multiply outputs by 2 so we've got something to backpropagate:
1074
- idx = 0
1075
- for i in range(2):
1076
- for j in range(2):
1077
- outcomponents[idx] = wptype(2) * m2result[i, j]
1078
- outcomponents_rightmul[idx] = wptype(2) * m2resultright[i, j]
1079
- idx = idx + 1
1080
-
1081
- for i in range(3):
1082
- for j in range(3):
1083
- outcomponents[idx] = wptype(2) * m3result[i, j]
1084
- outcomponents_rightmul[idx] = wptype(2) * m3resultright[i, j]
1085
- idx = idx + 1
1086
-
1087
- for i in range(4):
1088
- for j in range(4):
1089
- outcomponents[idx] = wptype(2) * m4result[i, j]
1090
- outcomponents_rightmul[idx] = wptype(2) * m4resultright[i, j]
1091
- idx = idx + 1
1092
-
1093
- for i in range(5):
1094
- for j in range(5):
1095
- outcomponents[idx] = wptype(2) * m5result[i, j]
1096
- outcomponents_rightmul[idx] = wptype(2) * m5resultright[i, j]
1097
- idx = idx + 1
1098
-
1099
- for i in range(2):
1100
- for j in range(2):
1101
- outcomponents[idx] = wptype(2) * m2result_2[i, j]
1102
- outcomponents_rightmul[idx] = wptype(2) * m2resultright_2[i, j]
1103
- idx = idx + 1
1104
-
1105
- for i in range(3):
1106
- for j in range(3):
1107
- outcomponents[idx] = wptype(2) * m3result_2[i, j]
1108
- outcomponents_rightmul[idx] = wptype(2) * m3resultright_2[i, j]
1109
- idx = idx + 1
1110
-
1111
- for i in range(4):
1112
- for j in range(4):
1113
- outcomponents[idx] = wptype(2) * m4result_2[i, j]
1114
- outcomponents_rightmul[idx] = wptype(2) * m4resultright_2[i, j]
1115
- idx = idx + 1
1116
-
1117
- for i in range(5):
1118
- for j in range(5):
1119
- outcomponents[idx] = wptype(2) * m5result_2[i, j]
1120
- outcomponents_rightmul[idx] = wptype(2) * m5resultright_2[i, j]
1121
- idx = idx + 1
1122
-
1123
- kernel = getkernel(check_mat_scalar_mul, suffix=dtype.__name__)
1124
-
1125
- if register_kernels:
1126
- return
1127
-
1128
- s = wp.array(randvals([1], dtype), requires_grad=True, device=device)
1129
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1130
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1131
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1132
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1133
- outcomponents = wp.zeros(2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5), dtype=wptype, requires_grad=True, device=device)
1134
- outcomponents_rightmul = wp.zeros(
1135
- 2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5), dtype=wptype, requires_grad=True, device=device
1136
- )
1137
-
1138
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents, outcomponents_rightmul], device=device)
1139
-
1140
- sval = s.numpy()[0]
1141
- assert_np_equal(outcomponents.numpy()[:4], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1142
- assert_np_equal(outcomponents.numpy()[4:13], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1143
- assert_np_equal(outcomponents.numpy()[13:29], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1144
- assert_np_equal(outcomponents.numpy()[29:54], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1145
-
1146
- assert_np_equal(outcomponents_rightmul.numpy()[:4], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1147
- assert_np_equal(outcomponents_rightmul.numpy()[4:13], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1148
- assert_np_equal(outcomponents_rightmul.numpy()[13:29], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1149
- assert_np_equal(outcomponents_rightmul.numpy()[29:54], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1150
-
1151
- assert_np_equal(outcomponents.numpy()[54:58], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1152
- assert_np_equal(outcomponents.numpy()[58:67], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1153
- assert_np_equal(outcomponents.numpy()[67:83], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1154
- assert_np_equal(outcomponents.numpy()[83:108], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1155
-
1156
- assert_np_equal(outcomponents_rightmul.numpy()[54:58], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1157
- assert_np_equal(outcomponents_rightmul.numpy()[58:67], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1158
- assert_np_equal(outcomponents_rightmul.numpy()[67:83], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1159
- assert_np_equal(outcomponents_rightmul.numpy()[83:108], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1160
-
1161
- if dtype in np_float_types:
1162
- idx = 0
1163
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1164
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
1165
- for i in range(dim):
1166
- for j in range(dim):
1167
- # test left mul gradient:
1168
- tape = wp.Tape()
1169
- with tape:
1170
- wp.launch(
1171
- kernel,
1172
- dim=1,
1173
- inputs=[s, m2, m3, m4, m5],
1174
- outputs=[outcomponents, outcomponents_rightmul],
1175
- device=device,
1176
- )
1177
- wp.launch(
1178
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1179
- )
1180
- tape.backward(loss=out)
1181
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1182
- expectedresult[i, j] = 2 * sval
1183
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
1184
- assert_np_equal(tape.gradients[s].numpy()[0], 2 * input.numpy()[0, i, j], tol=10 * tol)
1185
- tape.zero()
1186
-
1187
- # test right mul gradient:
1188
- tape = wp.Tape()
1189
- with tape:
1190
- wp.launch(
1191
- kernel,
1192
- dim=1,
1193
- inputs=[s, m2, m3, m4, m5],
1194
- outputs=[outcomponents, outcomponents_rightmul],
1195
- device=device,
1196
- )
1197
- wp.launch(
1198
- output_select_kernel,
1199
- dim=1,
1200
- inputs=[outcomponents_rightmul, idx],
1201
- outputs=[out],
1202
- device=device,
1203
- )
1204
- tape.backward(loss=out)
1205
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1206
- expectedresult[i, j] = 2 * sval
1207
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
1208
- assert_np_equal(tape.gradients[s].numpy()[0], 2 * input.numpy()[0, i, j], tol=10 * tol)
1209
- tape.zero()
1210
-
1211
- idx = idx + 1
1212
-
1213
-
1214
- def test_matvec_multiplication(test, device, dtype, register_kernels=False):
1215
- np.random.seed(123)
1216
-
1217
- tol = {
1218
- np.float16: 2.0e-2,
1219
- np.float32: 5.0e-6,
1220
- np.float64: 1.0e-8,
1221
- }.get(dtype, 0)
1222
-
1223
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1224
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1225
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1226
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1227
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1228
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1229
-
1230
- vec2 = wp.types.vector(length=2, dtype=wptype)
1231
- vec3 = wp.types.vector(length=3, dtype=wptype)
1232
- vec4 = wp.types.vector(length=4, dtype=wptype)
1233
- vec5 = wp.types.vector(length=5, dtype=wptype)
1234
-
1235
- output_select_kernel = get_select_kernel(wptype)
1236
-
1237
- def check_mat_vec_mul(
1238
- v2: wp.array(dtype=vec2),
1239
- v3: wp.array(dtype=vec3),
1240
- v4: wp.array(dtype=vec4),
1241
- v5: wp.array(dtype=vec5),
1242
- v32: wp.array(dtype=vec2),
1243
- m2: wp.array(dtype=mat22),
1244
- m3: wp.array(dtype=mat33),
1245
- m4: wp.array(dtype=mat44),
1246
- m5: wp.array(dtype=mat55),
1247
- m32: wp.array(dtype=mat32),
1248
- outcomponents: wp.array(dtype=wptype),
1249
- ):
1250
- v2result = m2[0] * v2[0]
1251
- v3result = m3[0] * v3[0]
1252
- v4result = m4[0] * v4[0]
1253
- v5result = m5[0] * v5[0]
1254
- v32result = m32[0] * v32[0]
1255
- v2result_2 = m2[0] @ v2[0]
1256
- v3result_2 = m3[0] @ v3[0]
1257
- v4result_2 = m4[0] @ v4[0]
1258
- v5result_2 = m5[0] @ v5[0]
1259
- v32result_2 = m32[0] @ v32[0]
1260
-
1261
- idx = 0
1262
-
1263
- # multiply outputs by 2 so we've got something to backpropagate:
1264
- for i in range(2):
1265
- outcomponents[idx] = wptype(2) * v2result[i]
1266
- idx = idx + 1
1267
-
1268
- for i in range(3):
1269
- outcomponents[idx] = wptype(2) * v3result[i]
1270
- idx = idx + 1
1271
-
1272
- for i in range(4):
1273
- outcomponents[idx] = wptype(2) * v4result[i]
1274
- idx = idx + 1
1275
-
1276
- for i in range(5):
1277
- outcomponents[idx] = wptype(2) * v5result[i]
1278
- idx = idx + 1
1279
-
1280
- for i in range(3):
1281
- outcomponents[idx] = wptype(2) * v32result[i]
1282
- idx = idx + 1
1283
-
1284
- for i in range(2):
1285
- outcomponents[idx] = wptype(2) * v2result_2[i]
1286
- idx = idx + 1
1287
-
1288
- for i in range(3):
1289
- outcomponents[idx] = wptype(2) * v3result_2[i]
1290
- idx = idx + 1
1291
-
1292
- for i in range(4):
1293
- outcomponents[idx] = wptype(2) * v4result_2[i]
1294
- idx = idx + 1
1295
-
1296
- for i in range(5):
1297
- outcomponents[idx] = wptype(2) * v5result_2[i]
1298
- idx = idx + 1
1299
-
1300
- for i in range(3):
1301
- outcomponents[idx] = wptype(2) * v32result_2[i]
1302
- idx = idx + 1
1303
-
1304
- kernel = getkernel(check_mat_vec_mul, suffix=dtype.__name__)
1305
-
1306
- if register_kernels:
1307
- return
1308
-
1309
- v2 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1310
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1311
- v4 = wp.array(randvals([1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1312
- v5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1313
- v32 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1314
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1315
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1316
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1317
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1318
- m32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1319
- outcomponents = wp.zeros(2 * (2 + 3 + 4 + 5 + 3), dtype=wptype, requires_grad=True, device=device)
1320
-
1321
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1322
-
1323
- assert_np_equal(outcomponents.numpy()[:2], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1324
- assert_np_equal(outcomponents.numpy()[2:5], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1325
- assert_np_equal(outcomponents.numpy()[5:9], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=5 * tol)
1326
- assert_np_equal(outcomponents.numpy()[9:14], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=5 * tol)
1327
- assert_np_equal(outcomponents.numpy()[14:17], 2 * np.matmul(m32.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1328
- assert_np_equal(outcomponents.numpy()[17:19], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1329
- assert_np_equal(outcomponents.numpy()[19:22], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1330
- assert_np_equal(outcomponents.numpy()[22:26], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=5 * tol)
1331
- assert_np_equal(outcomponents.numpy()[26:31], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=5 * tol)
1332
- assert_np_equal(outcomponents.numpy()[31:34], 2 * np.matmul(m32.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1333
-
1334
- if dtype in np_float_types:
1335
- idx = 0
1336
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1337
- for dim, invec, inmat in [(2, v2, m2), (3, v3, m3), (4, v4, m4), (5, v5, m5), (3, v32, m32)]:
1338
- for i in range(dim):
1339
- tape = wp.Tape()
1340
- with tape:
1341
- wp.launch(
1342
- kernel,
1343
- dim=1,
1344
- inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
1345
- outputs=[outcomponents],
1346
- device=device,
1347
- )
1348
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1349
- tape.backward(loss=out)
1350
-
1351
- assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, i, :], tol=2 * tol)
1352
- expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
1353
- expectedresult[i, :] = 2 * invec.numpy()[0]
1354
- assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)
1355
-
1356
- tape.zero()
1357
-
1358
- idx = idx + 1
1359
-
1360
-
1361
- def test_matmat_multiplication(test, device, dtype, register_kernels=False):
1362
- np.random.seed(123)
1363
-
1364
- tol = {
1365
- np.float16: 2.0e-2,
1366
- np.float32: 5.0e-6,
1367
- np.float64: 1.0e-8,
1368
- }.get(dtype, 0)
1369
-
1370
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1371
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1372
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1373
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1374
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1375
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1376
-
1377
- output_select_kernel = get_select_kernel(wptype)
1378
-
1379
- def check_mat_mat_mul(
1380
- a2: wp.array(dtype=mat22),
1381
- a3: wp.array(dtype=mat33),
1382
- a4: wp.array(dtype=mat44),
1383
- a5: wp.array(dtype=mat55),
1384
- a32: wp.array(dtype=mat32),
1385
- b2: wp.array(dtype=mat22),
1386
- b3: wp.array(dtype=mat33),
1387
- b4: wp.array(dtype=mat44),
1388
- b5: wp.array(dtype=mat55),
1389
- b32: wp.array(dtype=mat32),
1390
- outcomponents: wp.array(dtype=wptype),
1391
- ):
1392
- c2result = b2[0] * a2[0]
1393
- c3result = b3[0] * a3[0]
1394
- c4result = b4[0] * a4[0]
1395
- c5result = b5[0] * a5[0]
1396
- c32result = b32[0] * a2[0]
1397
- c32result2 = b3[0] * a32[0]
1398
- c2result_2 = b2[0] @ a2[0]
1399
- c3result_2 = b3[0] @ a3[0]
1400
- c4result_2 = b4[0] @ a4[0]
1401
- c5result_2 = b5[0] @ a5[0]
1402
- c32result_2 = b32[0] @ a2[0]
1403
- c32result2_2 = b3[0] @ a32[0]
1404
-
1405
- # multiply outputs by 2 so we've got something to backpropagate:
1406
- idx = 0
1407
- for i in range(2):
1408
- for j in range(2):
1409
- outcomponents[idx] = wptype(2) * c2result[i, j]
1410
- idx = idx + 1
1411
-
1412
- for i in range(3):
1413
- for j in range(3):
1414
- outcomponents[idx] = wptype(2) * c3result[i, j]
1415
- idx = idx + 1
1416
-
1417
- for i in range(4):
1418
- for j in range(4):
1419
- outcomponents[idx] = wptype(2) * c4result[i, j]
1420
- idx = idx + 1
1421
-
1422
- for i in range(5):
1423
- for j in range(5):
1424
- outcomponents[idx] = wptype(2) * c5result[i, j]
1425
- idx = idx + 1
1426
-
1427
- for i in range(3):
1428
- for j in range(2):
1429
- outcomponents[idx] = wptype(2) * c32result[i, j]
1430
- idx = idx + 1
1431
-
1432
- for i in range(3):
1433
- for j in range(2):
1434
- outcomponents[idx] = wptype(2) * c32result2[i, j]
1435
- idx = idx + 1
1436
-
1437
- for i in range(2):
1438
- for j in range(2):
1439
- outcomponents[idx] = wptype(2) * c2result_2[i, j]
1440
- idx = idx + 1
1441
-
1442
- for i in range(3):
1443
- for j in range(3):
1444
- outcomponents[idx] = wptype(2) * c3result_2[i, j]
1445
- idx = idx + 1
1446
-
1447
- for i in range(4):
1448
- for j in range(4):
1449
- outcomponents[idx] = wptype(2) * c4result_2[i, j]
1450
- idx = idx + 1
1451
-
1452
- for i in range(5):
1453
- for j in range(5):
1454
- outcomponents[idx] = wptype(2) * c5result_2[i, j]
1455
- idx = idx + 1
1456
-
1457
- for i in range(3):
1458
- for j in range(2):
1459
- outcomponents[idx] = wptype(2) * c32result_2[i, j]
1460
- idx = idx + 1
1461
-
1462
- for i in range(3):
1463
- for j in range(2):
1464
- outcomponents[idx] = wptype(2) * c32result2_2[i, j]
1465
- idx = idx + 1
1466
-
1467
- kernel = getkernel(check_mat_mat_mul, suffix=dtype.__name__)
1468
-
1469
- if register_kernels:
1470
- return
1471
-
1472
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1473
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1474
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1475
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1476
- v32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1477
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1478
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1479
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1480
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1481
- m32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1482
- outcomponents = wp.zeros(
1483
- 2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2 + 3 * 2), dtype=wptype, requires_grad=True, device=device
1484
- )
1485
-
1486
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1487
-
1488
- assert_np_equal(outcomponents.numpy()[:4], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1489
- assert_np_equal(outcomponents.numpy()[4:13], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1490
- assert_np_equal(outcomponents.numpy()[13:29], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=2 * tol)
1491
- assert_np_equal(outcomponents.numpy()[29:54], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=10 * tol)
1492
- assert_np_equal(outcomponents.numpy()[54:60], 2 * np.matmul(m32.numpy()[0], v2.numpy()[0]), tol=5 * tol)
1493
- assert_np_equal(outcomponents.numpy()[60:66], 2 * np.matmul(m3.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1494
- assert_np_equal(outcomponents.numpy()[66:70], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1495
- assert_np_equal(outcomponents.numpy()[70:79], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1496
- assert_np_equal(outcomponents.numpy()[79:95], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=2 * tol)
1497
- assert_np_equal(outcomponents.numpy()[95:120], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=10 * tol)
1498
- assert_np_equal(outcomponents.numpy()[120:126], 2 * np.matmul(m32.numpy()[0], v2.numpy()[0]), tol=5 * tol)
1499
- assert_np_equal(outcomponents.numpy()[126:132], 2 * np.matmul(m3.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1500
-
1501
- if dtype in np_float_types:
1502
- idx = 0
1503
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1504
- for v, m in [(v2, m2), (v3, m3), (v4, m4), (v5, m5), (v2, m32), (v32, m3)]:
1505
- rows, cols = m.dtype._shape_[0], v.dtype._shape_[1]
1506
- for i in range(rows):
1507
- for j in range(cols):
1508
- tape = wp.Tape()
1509
- with tape:
1510
- wp.launch(
1511
- kernel,
1512
- dim=1,
1513
- inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
1514
- outputs=[outcomponents],
1515
- device=device,
1516
- )
1517
- wp.launch(
1518
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1519
- )
1520
- tape.backward(loss=out)
1521
-
1522
- expected = np.zeros(v.dtype._shape_, dtype=dtype)
1523
- expected[:, j] = 2 * m.numpy()[0, i, :]
1524
- assert_np_equal(tape.gradients[v].numpy()[0], expected, tol=10 * tol)
1525
-
1526
- expected = np.zeros(m.dtype._shape_, dtype=dtype)
1527
- expected[i, :] = 2 * v.numpy()[0, :, j]
1528
- assert_np_equal(tape.gradients[m].numpy()[0], expected, tol=10 * tol)
1529
-
1530
- tape.zero()
1531
- idx = idx + 1
1532
-
1533
-
1534
- def test_cw_multiplication(test, device, dtype, register_kernels=False):
1535
- np.random.seed(123)
1536
-
1537
- tol = {
1538
- np.float16: 5.0e-2,
1539
- np.float32: 1.0e-6,
1540
- np.float64: 1.0e-8,
1541
- }.get(dtype, 0)
1542
-
1543
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1544
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1545
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1546
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1547
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1548
-
1549
- output_select_kernel = get_select_kernel(wptype)
1550
-
1551
- def check_mat_cw_mul(
1552
- s2: wp.array(dtype=mat22),
1553
- s3: wp.array(dtype=mat33),
1554
- s4: wp.array(dtype=mat44),
1555
- s5: wp.array(dtype=mat55),
1556
- v2: wp.array(dtype=mat22),
1557
- v3: wp.array(dtype=mat33),
1558
- v4: wp.array(dtype=mat44),
1559
- v5: wp.array(dtype=mat55),
1560
- outcomponents: wp.array(dtype=wptype),
1561
- ):
1562
- v2result = wptype(2) * wp.cw_mul(v2[0], s2[0])
1563
- v3result = wptype(2) * wp.cw_mul(v3[0], s3[0])
1564
- v4result = wptype(2) * wp.cw_mul(v4[0], s4[0])
1565
- v5result = wptype(2) * wp.cw_mul(v5[0], s5[0])
1566
-
1567
- # multiply outputs by 2 so we've got something to backpropagate:
1568
- idx = 0
1569
- for i in range(2):
1570
- for j in range(2):
1571
- outcomponents[idx] = v2result[i, j]
1572
- idx = idx + 1
1573
-
1574
- for i in range(3):
1575
- for j in range(3):
1576
- outcomponents[idx] = v3result[i, j]
1577
- idx = idx + 1
1578
-
1579
- for i in range(4):
1580
- for j in range(4):
1581
- outcomponents[idx] = v4result[i, j]
1582
- idx = idx + 1
1583
-
1584
- for i in range(5):
1585
- for j in range(5):
1586
- outcomponents[idx] = v5result[i, j]
1587
- idx = idx + 1
1588
-
1589
- kernel = getkernel(check_mat_cw_mul, suffix=dtype.__name__)
1590
-
1591
- if register_kernels:
1592
- return
1593
-
1594
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1595
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1596
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1597
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1598
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1599
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1600
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1601
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1602
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1603
-
1604
- wp.launch(
1605
- kernel,
1606
- dim=1,
1607
- inputs=[
1608
- s2,
1609
- s3,
1610
- s4,
1611
- s5,
1612
- v2,
1613
- v3,
1614
- v4,
1615
- v5,
1616
- ],
1617
- outputs=[outcomponents],
1618
- device=device,
1619
- )
1620
-
1621
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() * s2.numpy()).reshape(-1), tol=50 * tol)
1622
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() * s3.numpy()).reshape(-1), tol=50 * tol)
1623
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() * s4.numpy()).reshape(-1), tol=50 * tol)
1624
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() * s5.numpy()).reshape(-1), tol=50 * tol)
1625
-
1626
- if dtype in np_float_types:
1627
- idx = 0
1628
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1629
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
1630
- for i in range(dim):
1631
- for j in range(dim):
1632
- tape = wp.Tape()
1633
- with tape:
1634
- wp.launch(
1635
- kernel,
1636
- dim=1,
1637
- inputs=[
1638
- s2,
1639
- s3,
1640
- s4,
1641
- s5,
1642
- v2,
1643
- v3,
1644
- v4,
1645
- v5,
1646
- ],
1647
- outputs=[outcomponents],
1648
- device=device,
1649
- )
1650
- wp.launch(
1651
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1652
- )
1653
- tape.backward(loss=out)
1654
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1655
- expectedresult[i, j] = 2 * in1.numpy()[0][i, j]
1656
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=5 * tol)
1657
- expectedresult[i, j] = 2 * in2.numpy()[0][i, j]
1658
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=5 * tol)
1659
- tape.zero()
1660
-
1661
- idx = idx + 1
1662
-
1663
-
1664
- def test_cw_division(test, device, dtype, register_kernels=False):
1665
- np.random.seed(123)
1666
-
1667
- tol = {
1668
- np.float16: 1.0e-2,
1669
- np.float32: 1.0e-6,
1670
- np.float64: 1.0e-8,
1671
- }.get(dtype, 0)
1672
-
1673
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1674
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1675
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1676
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1677
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1678
-
1679
- output_select_kernel = get_select_kernel(wptype)
1680
-
1681
- def check_mat_cw_div(
1682
- s2: wp.array(dtype=mat22),
1683
- s3: wp.array(dtype=mat33),
1684
- s4: wp.array(dtype=mat44),
1685
- s5: wp.array(dtype=mat55),
1686
- v2: wp.array(dtype=mat22),
1687
- v3: wp.array(dtype=mat33),
1688
- v4: wp.array(dtype=mat44),
1689
- v5: wp.array(dtype=mat55),
1690
- outcomponents: wp.array(dtype=wptype),
1691
- ):
1692
- v2result = wptype(2) * wp.cw_div(v2[0], s2[0])
1693
- v3result = wptype(2) * wp.cw_div(v3[0], s3[0])
1694
- v4result = wptype(2) * wp.cw_div(v4[0], s4[0])
1695
- v5result = wptype(2) * wp.cw_div(v5[0], s5[0])
1696
-
1697
- # multiply outputs by 2 so we've got something to backpropagate:
1698
- idx = 0
1699
- for i in range(2):
1700
- for j in range(2):
1701
- outcomponents[idx] = v2result[i, j]
1702
- idx = idx + 1
1703
-
1704
- for i in range(3):
1705
- for j in range(3):
1706
- outcomponents[idx] = v3result[i, j]
1707
- idx = idx + 1
1708
-
1709
- for i in range(4):
1710
- for j in range(4):
1711
- outcomponents[idx] = v4result[i, j]
1712
- idx = idx + 1
1713
-
1714
- for i in range(5):
1715
- for j in range(5):
1716
- outcomponents[idx] = v5result[i, j]
1717
- idx = idx + 1
1718
-
1719
- kernel = getkernel(check_mat_cw_div, suffix=dtype.__name__)
1720
-
1721
- if register_kernels:
1722
- return
1723
-
1724
- s2 = randvals([1, 2, 2], dtype)
1725
- s3 = randvals([1, 3, 3], dtype)
1726
- s4 = randvals([1, 4, 4], dtype)
1727
- s5 = randvals([1, 5, 5], dtype)
1728
-
1729
- # set denominators to 1 if their magnitudes are small
1730
- # to prevent divide by zero, or overflows if we're testing
1731
- # float16:
1732
- s2[np.abs(s2) < 1.0e-2] = 1
1733
- s3[np.abs(s3) < 1.0e-2] = 1
1734
- s4[np.abs(s4) < 1.0e-2] = 1
1735
- s5[np.abs(s5) < 1.0e-2] = 1
1736
-
1737
- s2 = wp.array(s2, dtype=mat22, requires_grad=True, device=device)
1738
- s3 = wp.array(s3, dtype=mat33, requires_grad=True, device=device)
1739
- s4 = wp.array(s4, dtype=mat44, requires_grad=True, device=device)
1740
- s5 = wp.array(s5, dtype=mat55, requires_grad=True, device=device)
1741
-
1742
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1743
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1744
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1745
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1746
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1747
-
1748
- wp.launch(
1749
- kernel,
1750
- dim=1,
1751
- inputs=[
1752
- s2,
1753
- s3,
1754
- s4,
1755
- s5,
1756
- v2,
1757
- v3,
1758
- v4,
1759
- v5,
1760
- ],
1761
- outputs=[outcomponents],
1762
- device=device,
1763
- )
1764
-
1765
- if dtype in np_float_types:
1766
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() / s2.numpy()).reshape(-1), tol=50 * tol)
1767
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() / s3.numpy()).reshape(-1), tol=50 * tol)
1768
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() / s4.numpy()).reshape(-1), tol=50 * tol)
1769
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() / s5.numpy()).reshape(-1), tol=50 * tol)
1770
- else:
1771
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() // s2.numpy()).reshape(-1), tol=50 * tol)
1772
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() // s3.numpy()).reshape(-1), tol=50 * tol)
1773
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() // s4.numpy()).reshape(-1), tol=50 * tol)
1774
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() // s5.numpy()).reshape(-1), tol=50 * tol)
1775
-
1776
- if dtype in np_float_types:
1777
- idx = 0
1778
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1779
- for dim, s, v in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
1780
- for i in range(dim):
1781
- for j in range(dim):
1782
- tape = wp.Tape()
1783
- with tape:
1784
- wp.launch(
1785
- kernel,
1786
- dim=1,
1787
- inputs=[
1788
- s2,
1789
- s3,
1790
- s4,
1791
- s5,
1792
- v2,
1793
- v3,
1794
- v4,
1795
- v5,
1796
- ],
1797
- outputs=[outcomponents],
1798
- device=device,
1799
- )
1800
- wp.launch(
1801
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1802
- )
1803
- tape.backward(loss=out)
1804
-
1805
- # y = v/s
1806
- # dy/dv = 1.0/s
1807
- # dy/ds = -v/s^2
1808
-
1809
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1810
- expectedresult[i, j] = 2.0 / (s.numpy()[0, i, j])
1811
- assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=50 * tol)
1812
- expectedresult[i, j] = -2.0 * v.numpy()[0, i, j] / (s.numpy()[0, i, j] ** 2)
1813
- assert_np_equal(
1814
- tape.gradients[s].numpy()[0], expectedresult, tol=abs(outcomponents.numpy()[idx]) * 50 * tol
1815
- )
1816
- tape.zero()
1817
-
1818
- idx = idx + 1
1819
-
1820
-
1821
- def test_outer_product(test, device, dtype, register_kernels=False):
1822
- np.random.seed(123)
1823
-
1824
- tol = {
1825
- np.float16: 5.0e-3,
1826
- np.float32: 1.0e-6,
1827
- np.float64: 1.0e-8,
1828
- }.get(dtype, 0)
1829
-
1830
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1831
- vec2 = wp.types.vector(length=2, dtype=wptype)
1832
- vec3 = wp.types.vector(length=3, dtype=wptype)
1833
- vec4 = wp.types.vector(length=4, dtype=wptype)
1834
- vec5 = wp.types.vector(length=5, dtype=wptype)
1835
-
1836
- output_select_kernel = get_select_kernel(wptype)
1837
-
1838
- def check_mat_outer_product(
1839
- s2: wp.array(dtype=vec2),
1840
- s3: wp.array(dtype=vec3),
1841
- s4: wp.array(dtype=vec4),
1842
- s5: wp.array(dtype=vec5),
1843
- v2: wp.array(dtype=vec2),
1844
- v3: wp.array(dtype=vec3),
1845
- v4: wp.array(dtype=vec4),
1846
- v5: wp.array(dtype=vec5),
1847
- outcomponents: wp.array(dtype=wptype),
1848
- ):
1849
- m22result = wptype(2) * wp.outer(s2[0], v2[0])
1850
- m33result = wptype(2) * wp.outer(s3[0], v3[0])
1851
- m44result = wptype(2) * wp.outer(s4[0], v4[0])
1852
- m55result = wptype(2) * wp.outer(s5[0], v5[0])
1853
- m25result = wptype(2) * wp.outer(s2[0], v5[0])
1854
-
1855
- # multiply outputs by 2 so we've got something to backpropagate:
1856
- idx = 0
1857
- for i in range(2):
1858
- for j in range(2):
1859
- outcomponents[idx] = m22result[i, j]
1860
- idx = idx + 1
1861
-
1862
- for i in range(3):
1863
- for j in range(3):
1864
- outcomponents[idx] = m33result[i, j]
1865
- idx = idx + 1
1866
-
1867
- for i in range(4):
1868
- for j in range(4):
1869
- outcomponents[idx] = m44result[i, j]
1870
- idx = idx + 1
1871
-
1872
- for i in range(5):
1873
- for j in range(5):
1874
- outcomponents[idx] = m55result[i, j]
1875
- idx = idx + 1
1876
-
1877
- for i in range(2):
1878
- for j in range(5):
1879
- outcomponents[idx] = m25result[i, j]
1880
- idx = idx + 1
1881
-
1882
- kernel = getkernel(check_mat_outer_product, suffix=dtype.__name__)
1883
-
1884
- if register_kernels:
1885
- return
1886
-
1887
- s2 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1888
- s3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1889
- s4 = wp.array(randvals([1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1890
- s5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1891
- v2 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1892
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1893
- v4 = wp.array(randvals([1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1894
- v5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1895
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 5, dtype=wptype, requires_grad=True, device=device)
1896
-
1897
- wp.launch(kernel, dim=1, inputs=[s2, s3, s4, s5, v2, v3, v4, v5], outputs=[outcomponents], device=device)
1898
-
1899
- assert_np_equal(outcomponents.numpy()[:4], 2 * s2.numpy()[0, :, None] * v2.numpy()[0, None, :], tol=tol)
1900
- assert_np_equal(outcomponents.numpy()[4:13], 2 * s3.numpy()[0, :, None] * v3.numpy()[0, None, :], tol=10 * tol)
1901
- assert_np_equal(outcomponents.numpy()[13:29], 2 * s4.numpy()[0, :, None] * v4.numpy()[0, None, :], tol=10 * tol)
1902
- assert_np_equal(outcomponents.numpy()[29:54], 2 * s5.numpy()[0, :, None] * v5.numpy()[0, None, :], tol=10 * tol)
1903
- assert_np_equal(outcomponents.numpy()[54:], 2 * s2.numpy()[0, :, None] * v5.numpy()[0, None, :], tol=10 * tol)
1904
-
1905
- if dtype in np_float_types:
1906
- idx = 0
1907
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1908
- for s, v in [(s2, v2), (s3, v3), (s4, v4), (s5, v5), (s2, v5)]:
1909
- rows = s.dtype._length_
1910
- cols = v.dtype._length_
1911
- for i in range(rows):
1912
- for j in range(cols):
1913
- tape = wp.Tape()
1914
- with tape:
1915
- wp.launch(
1916
- kernel,
1917
- dim=1,
1918
- inputs=[
1919
- s2,
1920
- s3,
1921
- s4,
1922
- s5,
1923
- v2,
1924
- v3,
1925
- v4,
1926
- v5,
1927
- ],
1928
- outputs=[outcomponents],
1929
- device=device,
1930
- )
1931
- wp.launch(
1932
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1933
- )
1934
- tape.backward(loss=out)
1935
-
1936
- # this component's gonna be s_i * v_j, so its s gradient is gonna be nozero
1937
- # at the ith component and its v gradient will be nonzero at the jth component:
1938
-
1939
- expectedresult = np.zeros((rows), dtype=dtype)
1940
- expectedresult[i] = 2 * v.numpy()[0, j]
1941
- assert_np_equal(tape.gradients[s].numpy()[0], expectedresult, tol=10 * tol)
1942
-
1943
- expectedresult = np.zeros((cols), dtype=dtype)
1944
- expectedresult[j] = 2 * s.numpy()[0, i]
1945
- assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=10 * tol)
1946
- tape.zero()
1947
-
1948
- idx = idx + 1
1949
-
1950
-
1951
- def test_scalar_division(test, device, dtype, register_kernels=False):
1952
- np.random.seed(123)
1953
-
1954
- tol = {
1955
- np.float16: 1.0e-2,
1956
- np.float32: 1.0e-6,
1957
- np.float64: 1.0e-8,
1958
- }.get(dtype, 0)
1959
-
1960
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1961
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1962
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1963
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1964
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1965
-
1966
- output_select_kernel = get_select_kernel(wptype)
1967
-
1968
- def check_mat_scalar_div(
1969
- s: wp.array(dtype=wptype),
1970
- m2: wp.array(dtype=mat22),
1971
- m3: wp.array(dtype=mat33),
1972
- m4: wp.array(dtype=mat44),
1973
- m5: wp.array(dtype=mat55),
1974
- outcomponents: wp.array(dtype=wptype),
1975
- ):
1976
- m2result = m2[0] / s[0]
1977
- m3result = m3[0] / s[0]
1978
- m4result = m4[0] / s[0]
1979
- m5result = m5[0] / s[0]
1980
-
1981
- # multiply outputs by 2 so we've got something to backpropagate:
1982
- idx = 0
1983
- for i in range(2):
1984
- for j in range(2):
1985
- outcomponents[idx] = wptype(2) * m2result[i, j]
1986
- idx = idx + 1
1987
-
1988
- for i in range(3):
1989
- for j in range(3):
1990
- outcomponents[idx] = wptype(2) * m3result[i, j]
1991
- idx = idx + 1
1992
-
1993
- for i in range(4):
1994
- for j in range(4):
1995
- outcomponents[idx] = wptype(2) * m4result[i, j]
1996
- idx = idx + 1
1997
-
1998
- for i in range(5):
1999
- for j in range(5):
2000
- outcomponents[idx] = wptype(2) * m5result[i, j]
2001
- idx = idx + 1
2002
-
2003
- kernel = getkernel(check_mat_scalar_div, suffix=dtype.__name__)
2004
-
2005
- if register_kernels:
2006
- return
2007
-
2008
- s = wp.array(randvals([1], dtype), requires_grad=True, device=device)
2009
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2010
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2011
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2012
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2013
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2014
-
2015
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
2016
-
2017
- sval = s.numpy()[0]
2018
- if dtype in np_float_types:
2019
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy().reshape(-1) / sval, tol=tol)
2020
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy().reshape(-1) / sval, tol=10 * tol)
2021
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy().reshape(-1) / sval, tol=10 * tol)
2022
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy().reshape(-1) / sval, tol=10 * tol)
2023
- else:
2024
- assert_np_equal(outcomponents.numpy()[:4], 2 * (m2.numpy().reshape(-1) // sval), tol=tol)
2025
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (m3.numpy().reshape(-1) // sval), tol=10 * tol)
2026
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (m4.numpy().reshape(-1) // sval), tol=10 * tol)
2027
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (m5.numpy().reshape(-1) // sval), tol=10 * tol)
2028
-
2029
- if dtype in np_float_types:
2030
- idx = 0
2031
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2032
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
2033
- for i in range(dim):
2034
- for j in range(dim):
2035
- tape = wp.Tape()
2036
- with tape:
2037
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
2038
- wp.launch(
2039
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2040
- )
2041
- tape.backward(loss=out)
2042
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2043
- expectedresult[i, j] = 2.0 / sval
2044
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
2045
- assert_np_equal(
2046
- tape.gradients[s].numpy()[0], -2 * input.numpy()[0, i, j] / (sval * sval), tol=10 * tol
2047
- )
2048
- tape.zero()
2049
-
2050
- idx = idx + 1
2051
-
2052
-
2053
- def test_addition(test, device, dtype, register_kernels=False):
2054
- np.random.seed(123)
2055
-
2056
- tol = {
2057
- np.float16: 2.0e-2,
2058
- np.float32: 5.0e-6,
2059
- np.float64: 1.0e-8,
2060
- }.get(dtype, 0)
2061
-
2062
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2063
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2064
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2065
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2066
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2067
-
2068
- output_select_kernel = get_select_kernel(wptype)
2069
-
2070
- def check_mat_add(
2071
- s2: wp.array(dtype=mat22),
2072
- s3: wp.array(dtype=mat33),
2073
- s4: wp.array(dtype=mat44),
2074
- s5: wp.array(dtype=mat55),
2075
- v2: wp.array(dtype=mat22),
2076
- v3: wp.array(dtype=mat33),
2077
- v4: wp.array(dtype=mat44),
2078
- v5: wp.array(dtype=mat55),
383
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
384
+ mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
385
+
386
+ output_select_kernel = get_select_kernel(wptype)
387
+
388
+ def check_mat_negation(
389
+ m2: wp.array(dtype=mat22),
390
+ m3: wp.array(dtype=mat33),
391
+ m4: wp.array(dtype=mat44),
392
+ m5: wp.array(dtype=mat55),
2079
393
  outcomponents: wp.array(dtype=wptype),
2080
394
  ):
2081
- v2result = v2[0] + s2[0]
2082
- v3result = v3[0] + s3[0]
2083
- v4result = v4[0] + s4[0]
2084
- v5result = v5[0] + s5[0]
395
+ mat2 = -m2[0]
396
+ mat3 = -m3[0]
397
+ mat4 = -m4[0]
398
+ mat5 = -m5[0]
2085
399
 
2086
400
  # multiply outputs by 2 so we've got something to backpropagate:
2087
401
  idx = 0
2088
402
  for i in range(2):
2089
403
  for j in range(2):
2090
- outcomponents[idx] = wptype(2) * v2result[i, j]
404
+ outcomponents[idx] = wptype(2) * mat2[i, j]
2091
405
  idx = idx + 1
2092
406
 
2093
407
  for i in range(3):
2094
408
  for j in range(3):
2095
- outcomponents[idx] = wptype(2) * v3result[i, j]
409
+ outcomponents[idx] = wptype(2) * mat3[i, j]
2096
410
  idx = idx + 1
2097
411
 
2098
412
  for i in range(4):
2099
413
  for j in range(4):
2100
- outcomponents[idx] = wptype(2) * v4result[i, j]
414
+ outcomponents[idx] = wptype(2) * mat4[i, j]
2101
415
  idx = idx + 1
2102
416
 
2103
417
  for i in range(5):
2104
418
  for j in range(5):
2105
- outcomponents[idx] = wptype(2) * v5result[i, j]
419
+ outcomponents[idx] = wptype(2) * mat5[i, j]
2106
420
  idx = idx + 1
2107
421
 
2108
- kernel = getkernel(check_mat_add, suffix=dtype.__name__)
422
+ kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
2109
423
 
2110
424
  if register_kernels:
2111
425
  return
2112
426
 
2113
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2114
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2115
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2116
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2117
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2118
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2119
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2120
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
427
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
428
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
429
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
430
+ m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2121
431
  outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2122
432
 
2123
- wp.launch(
2124
- kernel,
2125
- dim=1,
2126
- inputs=[
2127
- s2,
2128
- s3,
2129
- s4,
2130
- s5,
2131
- v2,
2132
- v3,
2133
- v4,
2134
- v5,
2135
- ],
2136
- outputs=[outcomponents],
2137
- device=device,
2138
- )
433
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
2139
434
 
2140
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() + s2.numpy()).reshape(-1), tol=tol)
2141
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() + s3.numpy()).reshape(-1), tol=tol)
2142
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() + s4.numpy()).reshape(-1), tol=tol)
2143
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() + s5.numpy()).reshape(-1), tol=tol)
435
+ assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
436
+ assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
437
+ assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
438
+ assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
2144
439
 
2145
440
  if dtype in np_float_types:
2146
441
  idx = 0
2147
442
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2148
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
443
+ for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
2149
444
  for i in range(dim):
2150
445
  for j in range(dim):
2151
446
  tape = wp.Tape()
2152
447
  with tape:
2153
- wp.launch(
2154
- kernel,
2155
- dim=1,
2156
- inputs=[
2157
- s2,
2158
- s3,
2159
- s4,
2160
- s5,
2161
- v2,
2162
- v3,
2163
- v4,
2164
- v5,
2165
- ],
2166
- outputs=[outcomponents],
2167
- device=device,
2168
- )
448
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
2169
449
  wp.launch(
2170
450
  output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2171
451
  )
2172
452
  tape.backward(loss=out)
2173
453
  expectedresult = np.zeros((dim, dim), dtype=dtype)
2174
- expectedresult[i, j] = 2
2175
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=10 * tol)
2176
- expectedresult[i, j] = 2
2177
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=10 * tol)
454
+ expectedresult[i, j] = -2
455
+ assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
2178
456
  tape.zero()
2179
-
2180
457
  idx = idx + 1
2181
458
 
2182
459
 
2183
460
  def test_subtraction(test, device, dtype, register_kernels=False):
2184
- np.random.seed(123)
461
+ rng = np.random.default_rng(123)
2185
462
 
2186
463
  tol = {
2187
464
  np.float16: 5.0e-3,
@@ -2240,14 +517,14 @@ def test_subtraction(test, device, dtype, register_kernels=False):
2240
517
  if register_kernels:
2241
518
  return
2242
519
 
2243
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2244
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2245
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2246
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2247
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2248
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2249
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2250
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
520
+ s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
521
+ s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
522
+ s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
523
+ s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
524
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
525
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
526
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
527
+ v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2251
528
  outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2252
529
 
2253
530
  wp.launch(
@@ -2310,131 +587,8 @@ def test_subtraction(test, device, dtype, register_kernels=False):
2310
587
  idx = idx + 1
2311
588
 
2312
589
 
2313
- def test_ddot(test, device, dtype, register_kernels=False):
2314
- np.random.seed(123)
2315
-
2316
- tol = {
2317
- np.float16: 5.0e-3,
2318
- np.float32: 1.0e-6,
2319
- np.float64: 1.0e-8,
2320
- }.get(dtype, 0)
2321
-
2322
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2323
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2324
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2325
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2326
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2327
-
2328
- def check_mat_dot(
2329
- s2: wp.array(dtype=mat22),
2330
- s3: wp.array(dtype=mat33),
2331
- s4: wp.array(dtype=mat44),
2332
- s5: wp.array(dtype=mat55),
2333
- v2: wp.array(dtype=mat22),
2334
- v3: wp.array(dtype=mat33),
2335
- v4: wp.array(dtype=mat44),
2336
- v5: wp.array(dtype=mat55),
2337
- dot2: wp.array(dtype=wptype),
2338
- dot3: wp.array(dtype=wptype),
2339
- dot4: wp.array(dtype=wptype),
2340
- dot5: wp.array(dtype=wptype),
2341
- ):
2342
- # multiply outputs by 2 so we've got something to backpropagate:
2343
- dot2[0] = wptype(2) * wp.ddot(v2[0], s2[0])
2344
- dot3[0] = wptype(2) * wp.ddot(v3[0], s3[0])
2345
- dot4[0] = wptype(2) * wp.ddot(v4[0], s4[0])
2346
- dot5[0] = wptype(2) * wp.ddot(v5[0], s5[0])
2347
-
2348
- kernel = getkernel(check_mat_dot, suffix=dtype.__name__)
2349
-
2350
- if register_kernels:
2351
- return
2352
-
2353
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2354
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2355
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2356
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2357
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2358
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2359
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2360
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2361
- dot2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2362
- dot3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2363
- dot4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2364
- dot5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2365
-
2366
- tape = wp.Tape()
2367
- with tape:
2368
- wp.launch(
2369
- kernel,
2370
- dim=1,
2371
- inputs=[
2372
- s2,
2373
- s3,
2374
- s4,
2375
- s5,
2376
- v2,
2377
- v3,
2378
- v4,
2379
- v5,
2380
- ],
2381
- outputs=[dot2, dot3, dot4, dot5],
2382
- device=device,
2383
- )
2384
-
2385
- assert_np_equal(dot2.numpy()[0], 2 * (v2.numpy() * s2.numpy()).sum(), tol=10 * tol)
2386
- assert_np_equal(dot3.numpy()[0], 2 * (v3.numpy() * s3.numpy()).sum(), tol=10 * tol)
2387
- assert_np_equal(dot4.numpy()[0], 2 * (v4.numpy() * s4.numpy()).sum(), tol=50 * tol)
2388
- assert_np_equal(dot5.numpy()[0], 2 * (v5.numpy() * s5.numpy()).sum(), tol=200 * tol)
2389
-
2390
- if dtype in np_float_types:
2391
- tape.backward(loss=dot2)
2392
- sgrads = tape.gradients[s2].numpy()[0]
2393
- expected_grads = 2.0 * v2.numpy()[0]
2394
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2395
-
2396
- vgrads = tape.gradients[v2].numpy()[0]
2397
- expected_grads = 2.0 * s2.numpy()[0]
2398
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2399
-
2400
- tape.zero()
2401
-
2402
- tape.backward(loss=dot3)
2403
- sgrads = tape.gradients[s3].numpy()[0]
2404
- expected_grads = 2.0 * v3.numpy()[0]
2405
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2406
-
2407
- vgrads = tape.gradients[v3].numpy()[0]
2408
- expected_grads = 2.0 * s3.numpy()[0]
2409
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2410
-
2411
- tape.zero()
2412
-
2413
- tape.backward(loss=dot4)
2414
- sgrads = tape.gradients[s4].numpy()[0]
2415
- expected_grads = 2.0 * v4.numpy()[0]
2416
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2417
-
2418
- vgrads = tape.gradients[v4].numpy()[0]
2419
- expected_grads = 2.0 * s4.numpy()[0]
2420
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2421
-
2422
- tape.zero()
2423
-
2424
- tape.backward(loss=dot5)
2425
- sgrads = tape.gradients[s5].numpy()[0]
2426
- expected_grads = 2.0 * v5.numpy()[0]
2427
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2428
-
2429
- vgrads = tape.gradients[v5].numpy()[0]
2430
- expected_grads = 2.0 * s5.numpy()[0]
2431
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2432
-
2433
- tape.zero()
2434
-
2435
-
2436
590
  def test_determinant(test, device, dtype, register_kernels=False):
2437
- np.random.seed(123)
591
+ rng = np.random.default_rng(123)
2438
592
 
2439
593
  tol = {
2440
594
  np.float16: 5.0e-3,
@@ -2464,9 +618,9 @@ def test_determinant(test, device, dtype, register_kernels=False):
2464
618
  if register_kernels:
2465
619
  return
2466
620
 
2467
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2468
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2469
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
621
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
622
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
623
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2470
624
  det2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2471
625
  det3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2472
626
  det4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -2586,210 +740,114 @@ def test_determinant(test, device, dtype, register_kernels=False):
2586
740
  ],
2587
741
  outputs=[
2588
742
  det2,
2589
- det3,
2590
- det4,
2591
- ],
2592
- device=device,
2593
- )
2594
- dminus = det3.numpy()[0]
2595
- assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
2596
-
2597
- for i in range(4):
2598
- for j in range(4):
2599
- v4test = v4.numpy()
2600
- v4test[0, i, j] += dx
2601
- wp.launch(
2602
- kernel,
2603
- dim=1,
2604
- inputs=[
2605
- v2,
2606
- v3,
2607
- wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
2608
- ],
2609
- outputs=[
2610
- det2,
2611
- det3,
2612
- det4,
2613
- ],
2614
- device=device,
2615
- )
2616
- dplus = det4.numpy()[0]
2617
- v4test[0, i, j] -= 2.0 * dx
2618
- wp.launch(
2619
- kernel,
2620
- dim=1,
2621
- inputs=[
2622
- v2,
2623
- v3,
2624
- wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
2625
- ],
2626
- outputs=[
2627
- det2,
2628
- det3,
2629
- det4,
2630
- ],
2631
- device=device,
2632
- )
2633
- dminus = det4.numpy()[0]
2634
- assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
2635
-
2636
-
2637
- def test_trace(test, device, dtype, register_kernels=False):
2638
- np.random.seed(123)
2639
-
2640
- tol = {
2641
- np.float16: 1.0e-3,
2642
- np.float32: 1.0e-6,
2643
- np.float64: 1.0e-8,
2644
- }.get(dtype, 0)
2645
-
2646
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2647
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2648
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2649
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2650
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2651
-
2652
- def check_mat_trace(
2653
- v2: wp.array(dtype=mat22),
2654
- v3: wp.array(dtype=mat33),
2655
- v4: wp.array(dtype=mat44),
2656
- v5: wp.array(dtype=mat55),
2657
- tr2: wp.array(dtype=wptype),
2658
- tr3: wp.array(dtype=wptype),
2659
- tr4: wp.array(dtype=wptype),
2660
- tr5: wp.array(dtype=wptype),
2661
- ):
2662
- # multiply outputs by 2 so we've got something to backpropagate:
2663
- tr2[0] = wptype(2) * wp.trace(v2[0])
2664
- tr3[0] = wptype(2) * wp.trace(v3[0])
2665
- tr4[0] = wptype(2) * wp.trace(v4[0])
2666
- tr5[0] = wptype(2) * wp.trace(v5[0])
2667
-
2668
- kernel = getkernel(check_mat_trace, suffix=dtype.__name__)
2669
-
2670
- if register_kernels:
2671
- return
2672
-
2673
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2674
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2675
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2676
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2677
- tr2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2678
- tr3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2679
- tr4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2680
- tr5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2681
-
2682
- tape = wp.Tape()
2683
- with tape:
2684
- wp.launch(
2685
- kernel,
2686
- dim=1,
2687
- inputs=[
2688
- v2,
2689
- v3,
2690
- v4,
2691
- v5,
2692
- ],
2693
- outputs=[
2694
- tr2,
2695
- tr3,
2696
- tr4,
2697
- tr5,
2698
- ],
2699
- device=device,
2700
- )
2701
-
2702
- assert_np_equal(tr2.numpy()[0], 2 * np.trace(v2.numpy()[0]), tol=10 * tol)
2703
- assert_np_equal(tr3.numpy()[0], 2 * np.trace(v3.numpy()[0]), tol=10 * tol)
2704
- assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
2705
- assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
2706
-
2707
- if dtype in np_float_types:
2708
- tape.backward(loss=tr2)
2709
- vgrads = tape.gradients[v2].numpy()[0]
2710
- assert_np_equal(vgrads, 2.0 * np.eye(2), tol=10 * tol)
2711
- tape.zero()
2712
-
2713
- tape.backward(loss=tr3)
2714
- vgrads = tape.gradients[v3].numpy()[0]
2715
- assert_np_equal(vgrads, 2.0 * np.eye(3), tol=10 * tol)
2716
- tape.zero()
2717
-
2718
- tape.backward(loss=tr4)
2719
- vgrads = tape.gradients[v4].numpy()[0]
2720
- assert_np_equal(vgrads, 2.0 * np.eye(4), tol=10 * tol)
2721
- tape.zero()
2722
-
2723
- tape.backward(loss=tr5)
2724
- vgrads = tape.gradients[v5].numpy()[0]
2725
- assert_np_equal(vgrads, 2.0 * np.eye(5), tol=10 * tol)
2726
- tape.zero()
2727
-
2728
-
2729
- def test_diag(test, device, dtype, register_kernels=False):
2730
- np.random.seed(123)
2731
-
2732
- tol = {
2733
- np.float16: 1.0e-3,
2734
- np.float32: 1.0e-6,
2735
- np.float64: 1.0e-8,
2736
- }.get(dtype, 0)
2737
-
2738
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2739
- vec5 = wp.types.vector(length=5, dtype=wptype)
2740
-
2741
- output_select_kernel = get_select_kernel(wptype)
2742
-
2743
- def check_mat_diag(
2744
- s5: wp.array(dtype=vec5),
2745
- outcomponents: wp.array(dtype=wptype),
2746
- ):
2747
- # multiply outputs by 2 so we've got something to backpropagate:
2748
- m55result = wptype(2) * wp.diag(s5[0])
2749
-
2750
- idx = 0
2751
- for i in range(5):
2752
- for j in range(5):
2753
- outcomponents[idx] = m55result[i, j]
2754
- idx = idx + 1
2755
-
2756
- kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
2757
-
2758
- if register_kernels:
2759
- return
2760
-
2761
- s5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
2762
- outcomponents = wp.zeros(5 * 5, dtype=wptype, requires_grad=True, device=device)
2763
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2764
-
2765
- wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
743
+ det3,
744
+ det4,
745
+ ],
746
+ device=device,
747
+ )
748
+ dminus = det3.numpy()[0]
749
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
2766
750
 
2767
- assert_np_equal(outcomponents.numpy(), 2 * np.diag(s5.numpy()[0]), tol=tol)
751
+ for i in range(4):
752
+ for j in range(4):
753
+ v4test = v4.numpy()
754
+ v4test[0, i, j] += dx
755
+ wp.launch(
756
+ kernel,
757
+ dim=1,
758
+ inputs=[
759
+ v2,
760
+ v3,
761
+ wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
762
+ ],
763
+ outputs=[
764
+ det2,
765
+ det3,
766
+ det4,
767
+ ],
768
+ device=device,
769
+ )
770
+ dplus = det4.numpy()[0]
771
+ v4test[0, i, j] -= 2.0 * dx
772
+ wp.launch(
773
+ kernel,
774
+ dim=1,
775
+ inputs=[
776
+ v2,
777
+ v3,
778
+ wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
779
+ ],
780
+ outputs=[
781
+ det2,
782
+ det3,
783
+ det4,
784
+ ],
785
+ device=device,
786
+ )
787
+ dminus = det4.numpy()[0]
788
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
2768
789
 
2769
- if dtype in np_float_types:
2770
- idx = 0
2771
- for i in range(5):
2772
- for j in range(5):
2773
- tape = wp.Tape()
2774
- with tape:
2775
- wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
2776
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
2777
- tape.backward(loss=out)
2778
- expectedresult = np.zeros(5, dtype=dtype)
2779
- if i == j:
2780
- expectedresult[i] = 2
2781
- assert_np_equal(tape.gradients[s5].numpy()[0], expectedresult, tol=10 * tol)
2782
- tape.zero()
2783
790
 
2784
- idx = idx + 1
791
+ # Unused. Why?
792
+ # def test_get_diag(test, device, dtype, register_kernels=False):
793
+ # tol = {
794
+ # np.float16: 1.0e-3,
795
+ # np.float32: 1.0e-6,
796
+ # np.float64: 1.0e-8,
797
+ # }.get(dtype, 0)
798
+ #
799
+ # wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
800
+ # mat55 = wp.types.vector(shape=(5, 5), dtype=wptype)
801
+ #
802
+ # output_select_kernel = get_select_kernel(wptype)
803
+ #
804
+ # def check_mat_diag(
805
+ # m55: wp.array(dtype=mat55),
806
+ # outcomponents: wp.array(dtype=wptype),
807
+ # ):
808
+ # # multiply outputs by 2 so we've got something to backpropagate:
809
+ # vec5result = wptype(2) * wp.get_diag(m55[0])
810
+ #
811
+ # idx = 0
812
+ # for i in range(5):
813
+ # outcomponents[idx] = vec5result[i]
814
+ # idx = idx + 1
815
+ #
816
+ # kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
817
+ #
818
+ # if register_kernels:
819
+ # return
820
+ #
821
+ # m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
822
+ # outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
823
+ # out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
824
+ #
825
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
826
+ #
827
+ # assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
828
+ #
829
+ # if dtype in np_float_types:
830
+ # idx = 0
831
+ # for i in range(5):
832
+ # tape = wp.Tape()
833
+ # with tape:
834
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
835
+ # wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
836
+ # tape.backward(loss=out)
837
+ # expectedresult = np.zeros((5, 5), dtype=dtype)
838
+ # expectedresult[i, i] = 2
839
+ # assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
840
+ # tape.zero()
841
+ #
842
+ # idx = idx + 1
2785
843
 
2786
844
 
2787
845
  def test_inverse(test, device, dtype, register_kernels=False):
2788
- np.random.seed(123)
846
+ rng = np.random.default_rng(123)
2789
847
 
2790
848
  tol = {
2791
- np.float16: 2.0e-3,
2792
- np.float32: 1.0e-6,
849
+ np.float16: 5.0e-2,
850
+ np.float32: 1.0e-5,
2793
851
  np.float64: 1.0e-8,
2794
852
  }.get(dtype, 0)
2795
853
 
@@ -2832,9 +890,15 @@ def test_inverse(test, device, dtype, register_kernels=False):
2832
890
  if register_kernels:
2833
891
  return
2834
892
 
2835
- m2 = wp.array(2 * (randvals([1, 2, 2], dtype) + 0.2 * np.eye(2)), dtype=mat22, requires_grad=True, device=device)
2836
- m3 = wp.array(2 * (randvals([1, 3, 3], dtype) + 0.2 * np.eye(3)), dtype=mat33, requires_grad=True, device=device)
2837
- m4 = wp.array(2 * (randvals([1, 4, 4], dtype) + 0.2 * np.eye(4)), dtype=mat44, requires_grad=True, device=device)
893
+ m2 = wp.array(
894
+ 2 * (randvals(rng, [1, 2, 2], dtype) + 0.2 * np.eye(2)), dtype=mat22, requires_grad=True, device=device
895
+ )
896
+ m3 = wp.array(
897
+ 2 * (randvals(rng, [1, 3, 3], dtype) + 0.2 * np.eye(3)), dtype=mat33, requires_grad=True, device=device
898
+ )
899
+ m4 = wp.array(
900
+ 2 * (randvals(rng, [1, 4, 4], dtype) + 0.2 * np.eye(4)), dtype=mat44, requires_grad=True, device=device
901
+ )
2838
902
 
2839
903
  outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4, dtype=wptype, requires_grad=True, device=device)
2840
904
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -2949,7 +1013,7 @@ def test_inverse(test, device, dtype, register_kernels=False):
2949
1013
 
2950
1014
 
2951
1015
  def test_svd(test, device, dtype, register_kernels=False):
2952
- np.random.seed(123)
1016
+ rng = np.random.default_rng(123)
2953
1017
 
2954
1018
  tol = {
2955
1019
  np.float16: 1.0e-3,
@@ -3001,7 +1065,7 @@ def test_svd(test, device, dtype, register_kernels=False):
3001
1065
  if register_kernels:
3002
1066
  return
3003
1067
 
3004
- m3 = wp.array(randvals([1, 3, 3], dtype) + np.eye(3), dtype=mat33, requires_grad=True, device=device)
1068
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype) + np.eye(3), dtype=mat33, requires_grad=True, device=device)
3005
1069
 
3006
1070
  outcomponents = wp.zeros(2 * 3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
3007
1071
  Uout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
@@ -3068,7 +1132,7 @@ def test_svd(test, device, dtype, register_kernels=False):
3068
1132
 
3069
1133
 
3070
1134
  def test_qr(test, device, dtype, register_kernels=False):
3071
- np.random.seed(123)
1135
+ rng = np.random.default_rng(123)
3072
1136
 
3073
1137
  tol = {
3074
1138
  np.float16: 2.0e-3,
@@ -3111,7 +1175,7 @@ def test_qr(test, device, dtype, register_kernels=False):
3111
1175
  if register_kernels:
3112
1176
  return
3113
1177
 
3114
- m3 = wp.array(0.5 * (randvals([1, 3, 3], dtype) + np.eye(3)), dtype=mat33, requires_grad=True, device=device)
1178
+ m3 = wp.array(0.5 * (randvals(rng, [1, 3, 3], dtype) + np.eye(3)), dtype=mat33, requires_grad=True, device=device)
3115
1179
 
3116
1180
  outcomponents = wp.zeros(2 * 3 * 3, dtype=wptype, requires_grad=True, device=device)
3117
1181
  Qout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
@@ -3180,7 +1244,7 @@ def test_qr(test, device, dtype, register_kernels=False):
3180
1244
 
3181
1245
 
3182
1246
  def test_eig(test, device, dtype, register_kernels=False):
3183
- np.random.seed(123)
1247
+ rng = np.random.default_rng(123)
3184
1248
 
3185
1249
  tol = {
3186
1250
  np.float16: 4.0e-2,
@@ -3223,7 +1287,7 @@ def test_eig(test, device, dtype, register_kernels=False):
3223
1287
  if register_kernels:
3224
1288
  return
3225
1289
 
3226
- m3_np = randvals([1, 3, 3], dtype) + np.eye(3, dtype=dtype)
1290
+ m3_np = randvals(rng, [1, 3, 3], dtype) + np.eye(3, dtype=dtype)
3227
1291
  m3 = wp.array(m3_np, dtype=mat33, requires_grad=True, device=device)
3228
1292
 
3229
1293
  outcomponents = wp.zeros(3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
@@ -3292,7 +1356,7 @@ def test_eig(test, device, dtype, register_kernels=False):
3292
1356
 
3293
1357
 
3294
1358
  def test_skew(test, device, dtype, register_kernels=False):
3295
- np.random.seed(123)
1359
+ rng = np.random.default_rng(123)
3296
1360
 
3297
1361
  tol = {
3298
1362
  np.float16: 1.0e-3,
@@ -3323,7 +1387,7 @@ def test_skew(test, device, dtype, register_kernels=False):
3323
1387
  if register_kernels:
3324
1388
  return
3325
1389
 
3326
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1390
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
3327
1391
 
3328
1392
  outcomponents = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
3329
1393
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3394,7 +1458,7 @@ def test_skew(test, device, dtype, register_kernels=False):
3394
1458
 
3395
1459
 
3396
1460
  def test_transform_point(test, device, dtype, register_kernels=False):
3397
- np.random.seed(123)
1461
+ rng = np.random.default_rng(123)
3398
1462
 
3399
1463
  tol = {
3400
1464
  np.float16: 5.0e-3,
@@ -3425,8 +1489,8 @@ def test_transform_point(test, device, dtype, register_kernels=False):
3425
1489
  if register_kernels:
3426
1490
  return
3427
1491
 
3428
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
3429
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1492
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1493
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
3430
1494
 
3431
1495
  outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
3432
1496
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3455,7 +1519,7 @@ def test_transform_point(test, device, dtype, register_kernels=False):
3455
1519
 
3456
1520
 
3457
1521
  def test_transform_vector(test, device, dtype, register_kernels=False):
3458
- np.random.seed(123)
1522
+ rng = np.random.default_rng(123)
3459
1523
 
3460
1524
  tol = {
3461
1525
  np.float16: 5.0e-3,
@@ -3486,8 +1550,8 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
3486
1550
  if register_kernels:
3487
1551
  return
3488
1552
 
3489
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
3490
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1553
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1554
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
3491
1555
 
3492
1556
  outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
3493
1557
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3514,338 +1578,6 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
3514
1578
  tape.zero()
3515
1579
 
3516
1580
 
3517
- def test_anon_type_instance(test, device, dtype, register_kernels=False):
3518
- np.random.seed(123)
3519
-
3520
- tol = {
3521
- np.float16: 5.0e-3,
3522
- np.float32: 1.0e-6,
3523
- np.float64: 1.0e-8,
3524
- }.get(dtype, 0)
3525
-
3526
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3527
-
3528
- def check_scalar_init(
3529
- input: wp.array(dtype=wptype),
3530
- output: wp.array(dtype=wptype),
3531
- ):
3532
- m2result = wp.matrix(input[0], shape=(2, 2))
3533
- m3result = wp.matrix(input[1], shape=(3, 3))
3534
- m4result = wp.matrix(input[2], shape=(4, 4))
3535
- m5result = wp.matrix(input[3], shape=(5, 5))
3536
- m32result = wp.matrix(input[4], shape=(3, 2))
3537
-
3538
- idx = 0
3539
- for i in range(2):
3540
- for j in range(2):
3541
- output[idx] = wptype(2) * m2result[i, j]
3542
- idx = idx + 1
3543
- for i in range(3):
3544
- for j in range(3):
3545
- output[idx] = wptype(2) * m3result[i, j]
3546
- idx = idx + 1
3547
- for i in range(4):
3548
- for j in range(4):
3549
- output[idx] = wptype(2) * m4result[i, j]
3550
- idx = idx + 1
3551
- for i in range(5):
3552
- for j in range(5):
3553
- output[idx] = wptype(2) * m5result[i, j]
3554
- idx = idx + 1
3555
- for i in range(3):
3556
- for j in range(2):
3557
- output[idx] = wptype(2) * m32result[i, j]
3558
- idx = idx + 1
3559
-
3560
- def check_component_init(
3561
- input: wp.array(dtype=wptype),
3562
- output: wp.array(dtype=wptype),
3563
- ):
3564
- m2result = wp.matrix(input[0], input[1], input[2], input[3], shape=(2, 2))
3565
- m3result = wp.matrix(
3566
- input[4], input[5], input[6], input[7], input[8], input[9], input[10], input[11], input[12], shape=(3, 3)
3567
- )
3568
- m4result = wp.matrix(
3569
- input[13],
3570
- input[14],
3571
- input[15],
3572
- input[16],
3573
- input[17],
3574
- input[18],
3575
- input[19],
3576
- input[20],
3577
- input[21],
3578
- input[22],
3579
- input[23],
3580
- input[24],
3581
- input[25],
3582
- input[26],
3583
- input[27],
3584
- input[28],
3585
- shape=(4, 4),
3586
- )
3587
- m5result = wp.matrix(
3588
- input[29],
3589
- input[30],
3590
- input[31],
3591
- input[32],
3592
- input[33],
3593
- input[34],
3594
- input[35],
3595
- input[36],
3596
- input[37],
3597
- input[38],
3598
- input[39],
3599
- input[40],
3600
- input[41],
3601
- input[42],
3602
- input[43],
3603
- input[44],
3604
- input[45],
3605
- input[46],
3606
- input[47],
3607
- input[48],
3608
- input[49],
3609
- input[50],
3610
- input[51],
3611
- input[52],
3612
- input[53],
3613
- shape=(5, 5),
3614
- )
3615
- m32result = wp.matrix(input[54], input[55], input[56], input[57], input[58], input[59], shape=(3, 2))
3616
-
3617
- idx = 0
3618
- for i in range(2):
3619
- for j in range(2):
3620
- output[idx] = wptype(2) * m2result[i, j]
3621
- idx = idx + 1
3622
- for i in range(3):
3623
- for j in range(3):
3624
- output[idx] = wptype(2) * m3result[i, j]
3625
- idx = idx + 1
3626
- for i in range(4):
3627
- for j in range(4):
3628
- output[idx] = wptype(2) * m4result[i, j]
3629
- idx = idx + 1
3630
- for i in range(5):
3631
- for j in range(5):
3632
- output[idx] = wptype(2) * m5result[i, j]
3633
- idx = idx + 1
3634
- for i in range(3):
3635
- for j in range(2):
3636
- output[idx] = wptype(2) * m32result[i, j]
3637
- idx = idx + 1
3638
-
3639
- scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
3640
- component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
3641
- output_select_kernel = get_select_kernel(wptype)
3642
-
3643
- if register_kernels:
3644
- return
3645
-
3646
- input = wp.array(randvals([5], dtype), requires_grad=True, device=device)
3647
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2, dtype=wptype, requires_grad=True, device=device)
3648
-
3649
- wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3650
-
3651
- assert_np_equal(output.numpy()[:4], 2 * np.array([input.numpy()[0]] * 2 * 2), tol=1.0e-6)
3652
- assert_np_equal(output.numpy()[4:13], 2 * np.array([input.numpy()[1]] * 3 * 3), tol=1.0e-6)
3653
- assert_np_equal(output.numpy()[13:29], 2 * np.array([input.numpy()[2]] * 4 * 4), tol=1.0e-6)
3654
- assert_np_equal(output.numpy()[29:54], 2 * np.array([input.numpy()[3]] * 5 * 5), tol=1.0e-6)
3655
- assert_np_equal(output.numpy()[54:], 2 * np.array([input.numpy()[4]] * 3 * 2), tol=1.0e-6)
3656
-
3657
- if dtype in np_float_types:
3658
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3659
- for i in range(len(output)):
3660
- tape = wp.Tape()
3661
- with tape:
3662
- wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3663
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
3664
-
3665
- tape.backward(loss=out)
3666
- expected = np.zeros_like(input.numpy())
3667
- if i < 4:
3668
- expected[0] = 2
3669
- elif i < 13:
3670
- expected[1] = 2
3671
- elif i < 29:
3672
- expected[2] = 2
3673
- elif i < 54:
3674
- expected[3] = 2
3675
- else:
3676
- expected[4] = 2
3677
-
3678
- assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
3679
-
3680
- tape.reset()
3681
- tape.zero()
3682
-
3683
- input = wp.array(randvals([2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2], dtype), requires_grad=True, device=device)
3684
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2, dtype=wptype, requires_grad=True, device=device)
3685
-
3686
- wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3687
-
3688
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=1.0e-6)
3689
-
3690
- if dtype in np_float_types:
3691
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3692
- for i in range(len(output)):
3693
- tape = wp.Tape()
3694
- with tape:
3695
- wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3696
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
3697
-
3698
- tape.backward(loss=out)
3699
- expected = np.zeros_like(input.numpy())
3700
- expected[i] = 2
3701
-
3702
- assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
3703
-
3704
- tape.reset()
3705
- tape.zero()
3706
-
3707
-
3708
- def test_identity(test, device, dtype, register_kernels=False):
3709
- np.random.seed(123)
3710
-
3711
- tol = {
3712
- np.float16: 5.0e-3,
3713
- np.float32: 1.0e-6,
3714
- np.float64: 1.0e-8,
3715
- }.get(dtype, 0)
3716
-
3717
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3718
-
3719
- def check_identity_mat(
3720
- output: wp.array(dtype=wptype),
3721
- ):
3722
- m2result = wp.identity(dtype=wptype, n=2)
3723
- m3result = wp.identity(dtype=wptype, n=3)
3724
- m4result = wp.identity(dtype=wptype, n=4)
3725
- m5result = wp.identity(dtype=wptype, n=5)
3726
-
3727
- idx = 0
3728
- for i in range(2):
3729
- for j in range(2):
3730
- output[idx] = wptype(2) * m2result[i, j]
3731
- idx = idx + 1
3732
- for i in range(3):
3733
- for j in range(3):
3734
- output[idx] = wptype(2) * m3result[i, j]
3735
- idx = idx + 1
3736
- for i in range(4):
3737
- for j in range(4):
3738
- output[idx] = wptype(2) * m4result[i, j]
3739
- idx = idx + 1
3740
- for i in range(5):
3741
- for j in range(5):
3742
- output[idx] = wptype(2) * m5result[i, j]
3743
- idx = idx + 1
3744
-
3745
- id_kernel = getkernel(check_identity_mat, suffix=dtype.__name__)
3746
-
3747
- if register_kernels:
3748
- return
3749
-
3750
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
3751
- wp.launch(id_kernel, dim=1, inputs=[], outputs=[output], device=device)
3752
- assert_np_equal(output.numpy()[:4], 2 * np.eye(2), tol=1.0e-6)
3753
- assert_np_equal(output.numpy()[4:13], 2 * np.eye(3), tol=1.0e-6)
3754
- assert_np_equal(output.numpy()[13:29], 2 * np.eye(4), tol=1.0e-6)
3755
- assert_np_equal(output.numpy()[29:], 2 * np.eye(5), tol=1.0e-6)
3756
-
3757
-
3758
- def test_equivalent_types(test, device, dtype, register_kernels=False):
3759
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3760
-
3761
- # matrix types
3762
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
3763
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
3764
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
3765
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
3766
-
3767
- # matrix types equivalent to the above
3768
- mat22_equiv = wp.types.matrix(shape=(2, 2), dtype=wptype)
3769
- mat33_equiv = wp.types.matrix(shape=(3, 3), dtype=wptype)
3770
- mat44_equiv = wp.types.matrix(shape=(4, 4), dtype=wptype)
3771
- mat55_equiv = wp.types.matrix(shape=(5, 5), dtype=wptype)
3772
-
3773
- # declare kernel with original types
3774
- def check_equivalence(
3775
- m2: mat22,
3776
- m3: mat33,
3777
- m4: mat44,
3778
- m5: mat55,
3779
- ):
3780
- wp.expect_eq(m2, mat22(wptype(42)))
3781
- wp.expect_eq(m3, mat33(wptype(43)))
3782
- wp.expect_eq(m4, mat44(wptype(44)))
3783
- wp.expect_eq(m5, mat55(wptype(45)))
3784
-
3785
- wp.expect_eq(m2, mat22_equiv(wptype(42)))
3786
- wp.expect_eq(m3, mat33_equiv(wptype(43)))
3787
- wp.expect_eq(m4, mat44_equiv(wptype(44)))
3788
- wp.expect_eq(m5, mat55_equiv(wptype(45)))
3789
-
3790
- kernel = getkernel(check_equivalence, suffix=dtype.__name__)
3791
-
3792
- if register_kernels:
3793
- return
3794
-
3795
- # call kernel with equivalent types
3796
- m2 = mat22_equiv(42)
3797
- m3 = mat33_equiv(43)
3798
- m4 = mat44_equiv(44)
3799
- m5 = mat55_equiv(45)
3800
-
3801
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], device=device)
3802
-
3803
-
3804
- def test_conversions(test, device, dtype, register_kernels=False):
3805
- def check_matrices_equal(
3806
- m0: wp.mat22,
3807
- m1: wp.mat22,
3808
- m2: wp.mat22,
3809
- m3: wp.mat22,
3810
- m4: wp.mat22,
3811
- m5: wp.mat22,
3812
- m6: wp.mat22,
3813
- ):
3814
- wp.expect_eq(m1, m0)
3815
- wp.expect_eq(m2, m0)
3816
- wp.expect_eq(m3, m0)
3817
- wp.expect_eq(m4, m0)
3818
- wp.expect_eq(m5, m0)
3819
- wp.expect_eq(m6, m0)
3820
-
3821
- kernel = getkernel(check_matrices_equal, suffix=dtype.__name__)
3822
-
3823
- if register_kernels:
3824
- return
3825
-
3826
- m0 = wp.mat22(1, 2, 3, 4)
3827
-
3828
- # test explicit conversions - constructing matrices from different containers
3829
- m1 = wp.mat22(((1, 2), (3, 4))) # nested tuples
3830
- m2 = wp.mat22([[1, 2], [3, 4]]) # nested lists
3831
- m3 = wp.mat22(np.array([[1, 2], [3, 4]], dtype=dtype)) # 2d array
3832
- m4 = wp.mat22((1, 2, 3, 4)) # flat tuple
3833
- m5 = wp.mat22([1, 2, 3, 4]) # flat list
3834
- m6 = wp.mat22(np.array([1, 2, 3, 4], dtype=dtype)) # 1d array
3835
-
3836
- wp.launch(kernel, dim=1, inputs=[m0, m1, m2, m3, m4, m5, m6], device=device)
3837
-
3838
- # test implicit conversions - passing different containers as matrices to wp.launch()
3839
- m1 = ((1, 2), (3, 4)) # nested tuples
3840
- m2 = [[1, 2], [3, 4]] # nested lists
3841
- m3 = np.array([[1, 2], [3, 4]], dtype=dtype) # 2d array
3842
- m4 = (1, 2, 3, 4) # flat tuple
3843
- m5 = [1, 2, 3, 4] # flat list
3844
- m6 = np.array([1, 2, 3, 4], dtype=dtype) # 1d array
3845
-
3846
- wp.launch(kernel, dim=1, inputs=[m0, m1, m2, m3, m4, m5, m6], device=device)
3847
-
3848
-
3849
1581
  # Test matrix constructors using explicit type (float16)
3850
1582
  # note that these tests are specifically not using generics / closure
3851
1583
  # args to create kernels dynamically (like the rest of this file)
@@ -3869,6 +1601,22 @@ def test_constructors_explicit_precision():
3869
1601
  wp.expect_eq(custom[i, j], wp.float16(i) * wp.float16(2.0) + wp.float16(j))
3870
1602
 
3871
1603
 
1604
+ mat32d = wp.mat(shape=(3, 2), dtype=wp.float64)
1605
+
1606
+
1607
+ @wp.kernel
1608
+ def test_matrix_constructor_value_func():
1609
+ a = wp.mat22()
1610
+ b = wp.matrix(a, shape=(2, 2))
1611
+ c = mat32d()
1612
+ d = mat32d(c, shape=(3, 2))
1613
+ e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))
1614
+ f = mat32d(
1615
+ wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1616
+ wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1617
+ )
1618
+
1619
+
3872
1620
  # Same as above but with a default (float/int) type
3873
1621
  # which tests some different code paths that
3874
1622
  # need to ensure types are correctly canonicalized
@@ -3931,167 +1679,149 @@ def test_constructors_constant_shape():
3931
1679
  m[i, j] = float(i * j)
3932
1680
 
3933
1681
 
3934
- def register(parent):
3935
- devices = get_test_devices()
3936
-
3937
- class TestMat(parent):
3938
- pass
3939
-
3940
- add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
3941
- add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
3942
- add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
3943
-
3944
- mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
3945
- add_kernel_test(
3946
- TestMat,
3947
- test_matrix_mutation,
3948
- dim=1,
3949
- inputs=[
3950
- mat103(
3951
- 1.0,
3952
- 2.0,
3953
- 3.0,
3954
- 2.0,
3955
- 4.0,
3956
- 6.0,
3957
- 3.0,
3958
- 6.0,
3959
- 9.0,
3960
- 4.0,
3961
- 8.0,
3962
- 12.0,
3963
- 5.0,
3964
- 10.0,
3965
- 15.0,
3966
- 6.0,
3967
- 12.0,
3968
- 18.0,
3969
- 7.0,
3970
- 14.0,
3971
- 21.0,
3972
- 8.0,
3973
- 16.0,
3974
- 24.0,
3975
- 9.0,
3976
- 18.0,
3977
- 27.0,
3978
- 10.0,
3979
- 20.0,
3980
- 30.0,
3981
- )
3982
- ],
3983
- devices=devices,
3984
- )
3985
-
3986
- for dtype in np_signed_int_types + np_float_types:
3987
- add_function_test_register_kernel(
3988
- TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
3989
- )
3990
- add_function_test_register_kernel(
3991
- TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
3992
- )
3993
-
3994
- for dtype in np_scalar_types:
3995
- add_function_test(TestMat, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
3996
- add_function_test_register_kernel(
3997
- TestMat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
3998
- )
3999
- add_function_test_register_kernel(
4000
- TestMat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
4001
- )
4002
- add_function_test_register_kernel(
4003
- TestMat, f"test_identity_{dtype.__name__}", test_identity, devices=devices, dtype=dtype
4004
- )
4005
- add_function_test_register_kernel(
4006
- TestMat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
4007
- )
4008
- add_function_test_register_kernel(
4009
- TestMat, f"test_equality_{dtype.__name__}", test_equality, devices=devices, dtype=dtype
4010
- )
4011
- add_function_test_register_kernel(
4012
- TestMat,
4013
- f"test_scalar_multiplication_{dtype.__name__}",
4014
- test_scalar_multiplication,
4015
- devices=devices,
4016
- dtype=dtype,
4017
- )
4018
- add_function_test_register_kernel(
4019
- TestMat,
4020
- f"test_matvec_multiplication_{dtype.__name__}",
4021
- test_matvec_multiplication,
4022
- devices=devices,
4023
- dtype=dtype,
4024
- )
4025
- add_function_test_register_kernel(
4026
- TestMat,
4027
- f"test_matmat_multiplication_{dtype.__name__}",
4028
- test_matmat_multiplication,
4029
- devices=devices,
4030
- dtype=dtype,
4031
- )
4032
- add_function_test_register_kernel(
4033
- TestMat, f"test_cw_multiplication_{dtype.__name__}", test_cw_multiplication, devices=devices, dtype=dtype
4034
- )
4035
- add_function_test_register_kernel(
4036
- TestMat, f"test_cw_division_{dtype.__name__}", test_cw_division, devices=devices, dtype=dtype
4037
- )
4038
- add_function_test_register_kernel(
4039
- TestMat, f"test_outer_product_{dtype.__name__}", test_outer_product, devices=devices, dtype=dtype
4040
- )
4041
- add_function_test_register_kernel(
4042
- TestMat, f"test_transpose_{dtype.__name__}", test_transpose, devices=devices, dtype=dtype
4043
- )
4044
- add_function_test_register_kernel(
4045
- TestMat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
4046
- )
4047
- add_function_test_register_kernel(
4048
- TestMat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
4049
- )
4050
- add_function_test_register_kernel(
4051
- TestMat, f"test_ddot_{dtype.__name__}", test_ddot, devices=devices, dtype=dtype
4052
- )
4053
- add_function_test_register_kernel(
4054
- TestMat, f"test_trace_{dtype.__name__}", test_trace, devices=devices, dtype=dtype
4055
- )
4056
- add_function_test_register_kernel(
4057
- TestMat, f"test_diag_{dtype.__name__}", test_diag, devices=devices, dtype=dtype
4058
- )
4059
- add_function_test_register_kernel(
4060
- TestMat, f"test_equivalent_types_{dtype.__name__}", test_equivalent_types, devices=devices, dtype=dtype
4061
- )
4062
- add_function_test_register_kernel(
4063
- TestMat, f"test_conversions_{dtype.__name__}", test_conversions, devices=devices, dtype=dtype
4064
- )
4065
- add_function_test_register_kernel(
4066
- TestMat, f"test_constants_{dtype.__name__}", test_constants, devices=devices, dtype=dtype
1682
+ devices = get_test_devices()
1683
+
1684
+
1685
+ class TestMat(unittest.TestCase):
1686
+ pass
1687
+
1688
+
1689
+ add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
1690
+ add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
1691
+ add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
1692
+ add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
1693
+
1694
+ mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
1695
+ add_kernel_test(
1696
+ TestMat,
1697
+ test_matrix_mutation,
1698
+ dim=1,
1699
+ inputs=[
1700
+ mat103(
1701
+ 1.0,
1702
+ 2.0,
1703
+ 3.0,
1704
+ 2.0,
1705
+ 4.0,
1706
+ 6.0,
1707
+ 3.0,
1708
+ 6.0,
1709
+ 9.0,
1710
+ 4.0,
1711
+ 8.0,
1712
+ 12.0,
1713
+ 5.0,
1714
+ 10.0,
1715
+ 15.0,
1716
+ 6.0,
1717
+ 12.0,
1718
+ 18.0,
1719
+ 7.0,
1720
+ 14.0,
1721
+ 21.0,
1722
+ 8.0,
1723
+ 16.0,
1724
+ 24.0,
1725
+ 9.0,
1726
+ 18.0,
1727
+ 27.0,
1728
+ 10.0,
1729
+ 20.0,
1730
+ 30.0,
4067
1731
  )
1732
+ ],
1733
+ devices=devices,
1734
+ )
4068
1735
 
4069
- for dtype in np_float_types:
4070
- add_function_test_register_kernel(
4071
- TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
4072
- )
4073
- add_function_test_register_kernel(
4074
- TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
4075
- )
4076
- add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
4077
- add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
4078
- add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
4079
- add_function_test_register_kernel(
4080
- TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
4081
- )
4082
- add_function_test_register_kernel(
4083
- TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
4084
- )
4085
- add_function_test_register_kernel(
4086
- TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
4087
- )
4088
- add_function_test_register_kernel(
4089
- TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype
4090
- )
1736
+ for dtype in np_signed_int_types + np_float_types:
1737
+ add_function_test_register_kernel(
1738
+ TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
1739
+ )
1740
+ add_function_test_register_kernel(
1741
+ TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
1742
+ )
4091
1743
 
4092
- return TestMat
1744
+ add_function_test(
1745
+ TestMat,
1746
+ "test_anon_constructor_error_shape_keyword_missing",
1747
+ test_anon_constructor_error_shape_keyword_missing,
1748
+ devices=devices,
1749
+ )
1750
+ add_function_test(
1751
+ TestMat,
1752
+ "test_anon_constructor_error_dtype_keyword_missing",
1753
+ test_anon_constructor_error_dtype_keyword_missing,
1754
+ devices=devices,
1755
+ )
1756
+ add_function_test(
1757
+ TestMat,
1758
+ "test_anon_constructor_error_shape_mismatch",
1759
+ test_anon_constructor_error_shape_mismatch,
1760
+ devices=devices,
1761
+ )
1762
+ add_function_test(
1763
+ TestMat,
1764
+ "test_anon_constructor_error_invalid_arg_count",
1765
+ test_anon_constructor_error_invalid_arg_count,
1766
+ devices=devices,
1767
+ )
1768
+ add_function_test(
1769
+ TestMat,
1770
+ "test_tpl_constructor_error_incompatible_sizes",
1771
+ test_tpl_constructor_error_incompatible_sizes,
1772
+ devices=devices,
1773
+ )
1774
+ add_function_test(
1775
+ TestMat,
1776
+ "test_tpl_constructor_error_invalid_scalar_type",
1777
+ test_tpl_constructor_error_invalid_scalar_type,
1778
+ devices=devices,
1779
+ )
1780
+ add_function_test(
1781
+ TestMat,
1782
+ "test_tpl_constructor_error_invalid_vector_count",
1783
+ test_tpl_constructor_error_invalid_vector_count,
1784
+ devices=devices,
1785
+ )
1786
+ add_function_test(
1787
+ TestMat,
1788
+ "test_tpl_constructor_error_invalid_vector_shape",
1789
+ test_tpl_constructor_error_invalid_vector_shape,
1790
+ devices=devices,
1791
+ )
1792
+ add_function_test(
1793
+ TestMat,
1794
+ "test_tpl_constructor_error_invalid_arg_count",
1795
+ test_tpl_constructor_error_invalid_arg_count,
1796
+ devices=devices,
1797
+ )
1798
+ add_function_test(TestMat, "test_tpl_ops_with_anon", test_tpl_ops_with_anon)
1799
+
1800
+ for dtype in np_float_types:
1801
+ add_function_test(
1802
+ TestMat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
1803
+ )
1804
+ add_function_test_register_kernel(
1805
+ TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
1806
+ )
1807
+ add_function_test_register_kernel(
1808
+ TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
1809
+ )
1810
+ add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
1811
+ add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
1812
+ add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
1813
+ add_function_test_register_kernel(
1814
+ TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
1815
+ )
1816
+ add_function_test_register_kernel(
1817
+ TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
1818
+ )
1819
+ add_function_test_register_kernel(
1820
+ TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
1821
+ )
1822
+ add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
4093
1823
 
4094
1824
 
4095
1825
  if __name__ == "__main__":
4096
- c = register(unittest.TestCase)
1826
+ wp.build.clear_kernel_cache()
4097
1827
  unittest.main(verbosity=2, failfast=True)