warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271) hide show
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_mat.py CHANGED
@@ -5,9 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import unittest
9
+
8
10
  import numpy as np
11
+
9
12
  import warp as wp
10
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
11
14
 
12
15
  wp.init()
13
16
 
@@ -19,37 +22,24 @@ np_signed_int_types = [
19
22
  np.byte,
20
23
  ]
21
24
 
22
- np_unsigned_int_types = [
23
- np.uint8,
24
- np.uint16,
25
- np.uint32,
26
- np.uint64,
27
- np.ubyte,
28
- ]
29
-
30
- np_int_types = np_signed_int_types + np_unsigned_int_types
31
-
32
25
  np_float_types = [np.float16, np.float32, np.float64]
33
26
 
34
- np_scalar_types = np_int_types + np_float_types
35
-
36
27
 
37
- def randvals(shape, dtype):
28
+ def randvals(rng, shape, dtype):
38
29
  if dtype in np_float_types:
39
- return np.random.randn(*shape).astype(dtype)
30
+ return rng.standard_normal(size=shape).astype(dtype)
40
31
  elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
41
- return np.random.randint(1, 3, size=shape, dtype=dtype)
42
- return np.random.randint(1, 5, size=shape, dtype=dtype)
32
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
33
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
43
34
 
44
35
 
45
36
  kernel_cache = dict()
46
37
 
47
38
 
48
39
  def getkernel(func, suffix=""):
49
- module = wp.get_module(func.__module__)
50
40
  key = func.__name__ + "_" + suffix
51
41
  if key not in kernel_cache:
52
- kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
42
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
53
43
  return kernel_cache[key]
54
44
 
55
45
 
@@ -63,376 +53,224 @@ def get_select_kernel(dtype):
63
53
 
64
54
  return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
65
55
 
56
+ wp.launch(kernel, dim=1, inputs=[])
66
57
 
67
- def test_arrays(test, device, dtype):
68
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
69
-
70
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
71
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
72
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
73
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
74
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
75
-
76
- np.random.seed(123)
77
58
 
78
- v2_np = randvals([10, 2, 2], dtype)
79
- v3_np = randvals([10, 3, 3], dtype)
80
- v4_np = randvals([10, 4, 4], dtype)
81
- v5_np = randvals([10, 5, 5], dtype)
82
- v32_np = randvals([10, 3, 2], dtype)
59
+ def test_anon_constructor_error_shape_keyword_missing(test, device):
60
+ @wp.kernel
61
+ def kernel():
62
+ wp.matrix(1.0, 2.0, 3.0)
83
63
 
84
- v2 = wp.array(v2_np, dtype=mat22, requires_grad=True, device=device)
85
- v3 = wp.array(v3_np, dtype=mat33, requires_grad=True, device=device)
86
- v4 = wp.array(v4_np, dtype=mat44, requires_grad=True, device=device)
87
- v5 = wp.array(v5_np, dtype=mat55, requires_grad=True, device=device)
88
- v32 = wp.array(v32_np, dtype=mat32, requires_grad=True, device=device)
64
+ with test.assertRaisesRegex(
65
+ RuntimeError,
66
+ r"shape keyword must be specified when calling matrix\(\) function$",
67
+ ):
68
+ wp.launch(
69
+ kernel,
70
+ dim=1,
71
+ inputs=[],
72
+ device=device,
73
+ )
89
74
 
90
- assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
91
- assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
92
- assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
93
- assert_np_equal(v5.numpy(), v5_np, tol=1.0e-6)
94
- assert_np_equal(v32.numpy(), v32_np, tol=1.0e-6)
95
75
 
96
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
97
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
98
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
76
+ def test_anon_constructor_error_dtype_keyword_missing(test, device):
77
+ @wp.kernel
78
+ def kernel():
79
+ wp.matrix(shape=(3, 3))
99
80
 
100
- v2 = wp.array(v2_np, dtype=mat22, requires_grad=True, device=device)
101
- v3 = wp.array(v3_np, dtype=mat33, requires_grad=True, device=device)
102
- v4 = wp.array(v4_np, dtype=mat44, requires_grad=True, device=device)
81
+ with test.assertRaisesRegex(
82
+ RuntimeError,
83
+ r"matrix\(\) must have dtype as a keyword argument if it has no " r"positional arguments$",
84
+ ):
85
+ wp.launch(
86
+ kernel,
87
+ dim=1,
88
+ inputs=[],
89
+ device=device,
90
+ )
103
91
 
104
- assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
105
- assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
106
- assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
107
92
 
93
+ def test_anon_constructor_error_shape_mismatch(test, device):
94
+ @wp.kernel
95
+ def kernel():
96
+ wp.matrix(
97
+ wp.matrix(shape=(1, 2), dtype=float),
98
+ shape=(3, 4),
99
+ dtype=float,
100
+ )
108
101
 
109
- def test_components(test, device, dtype):
110
- # test accessing matrix components from Python - this is especially important
111
- # for float16, which requires special handling internally
102
+ with test.assertRaisesRegex(
103
+ RuntimeError,
104
+ r"Incompatible matrix sizes for casting copy constructor, " r"\(3, 4\) vs \(1, 2\)$",
105
+ ):
106
+ wp.launch(
107
+ kernel,
108
+ dim=1,
109
+ inputs=[],
110
+ device=device,
111
+ )
112
112
 
113
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
114
- mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)
115
-
116
- m = mat23(1, 2, 3, 4, 5, 6)
117
-
118
- # test __getitem__ for row vectors
119
- r0 = m[0]
120
- r1 = m[1]
121
- test.assertEqual(r0[0], 1)
122
- test.assertEqual(r0[1], 2)
123
- test.assertEqual(r0[2], 3)
124
- test.assertEqual(r1[0], 4)
125
- test.assertEqual(r1[1], 5)
126
- test.assertEqual(r1[2], 6)
127
-
128
- # test __getitem__ for individual components
129
- test.assertEqual(m[0, 0], 1)
130
- test.assertEqual(m[0, 1], 2)
131
- test.assertEqual(m[0, 2], 3)
132
- test.assertEqual(m[1, 0], 4)
133
- test.assertEqual(m[1, 1], 5)
134
- test.assertEqual(m[1, 2], 6)
135
-
136
- # test __setitem__ for row vectors
137
- m[0] = [7, 8, 9]
138
- m[1] = [10, 11, 12]
139
- test.assertEqual(m[0, 0], 7)
140
- test.assertEqual(m[0, 1], 8)
141
- test.assertEqual(m[0, 2], 9)
142
- test.assertEqual(m[1, 0], 10)
143
- test.assertEqual(m[1, 1], 11)
144
- test.assertEqual(m[1, 2], 12)
145
-
146
- # test __setitem__ for individual components
147
- m[0, 0] = 13
148
- m[0, 1] = 14
149
- m[0, 2] = 15
150
- m[1, 0] = 16
151
- m[1, 1] = 17
152
- m[1, 2] = 18
153
- test.assertEqual(m[0, 0], 13)
154
- test.assertEqual(m[0, 1], 14)
155
- test.assertEqual(m[0, 2], 15)
156
- test.assertEqual(m[1, 0], 16)
157
- test.assertEqual(m[1, 1], 17)
158
- test.assertEqual(m[1, 2], 18)
159
-
160
-
161
- def test_constants(test, device, dtype, register_kernels=False):
162
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
163
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
164
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
165
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
166
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
167
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
168
113
 
169
- cm22 = wp.constant(mat22(22))
170
- cm33 = wp.constant(mat33(33))
171
- cm44 = wp.constant(mat44(44))
172
- cm55 = wp.constant(mat55(55))
173
- cm32 = wp.constant(mat32(32))
114
+ def test_anon_constructor_error_invalid_arg_count(test, device):
115
+ @wp.kernel
116
+ def kernel():
117
+ wp.matrix(1.0, 2.0, 3.0, shape=(2, 2), dtype=float)
174
118
 
175
- def check_matrix_constants():
176
- wp.expect_eq(cm22, mat22(wptype(22)))
177
- wp.expect_eq(cm33, mat33(wptype(33)))
178
- wp.expect_eq(cm44, mat44(wptype(44)))
179
- wp.expect_eq(cm55, mat55(wptype(55)))
180
- wp.expect_eq(cm32, mat32(wptype(32)))
119
+ with test.assertRaisesRegex(
120
+ RuntimeError,
121
+ r"Wrong number of arguments for matrix\(\) function, must initialize "
122
+ r"with either a scalar value, or m\*n values$",
123
+ ):
124
+ wp.launch(
125
+ kernel,
126
+ dim=1,
127
+ inputs=[],
128
+ device=device,
129
+ )
181
130
 
182
- kernel = getkernel(check_matrix_constants, suffix=dtype.__name__)
183
131
 
184
- if register_kernels:
185
- return
132
+ def test_tpl_constructor_error_incompatible_sizes(test, device):
133
+ @wp.kernel
134
+ def kernel():
135
+ wp.mat33(wp.mat22(1.0, 2.0, 3.0, 4.0))
186
136
 
187
- wp.launch(kernel, dim=1, inputs=[])
137
+ with test.assertRaisesRegex(
138
+ RuntimeError,
139
+ r"Incompatible matrix sizes for casting copy constructor, " r"\(3, 3\) vs \(2, 2\)$",
140
+ ):
141
+ wp.launch(
142
+ kernel,
143
+ dim=1,
144
+ inputs=[],
145
+ device=device,
146
+ )
188
147
 
189
148
 
190
- def test_constructors(test, device, dtype, register_kernels=False):
191
- np.random.seed(123)
149
+ def test_tpl_constructor_error_invalid_scalar_type(test, device):
150
+ @wp.kernel
151
+ def kernel():
152
+ wp.mat22(1, 2, 3, 4)
192
153
 
193
- tol = {
194
- np.float16: 1.0e-3,
195
- np.float32: 1.0e-6,
196
- np.float64: 1.0e-8,
197
- }.get(dtype, 0)
154
+ with test.assertRaisesRegex(
155
+ RuntimeError,
156
+ r"Wrong scalar type for mat 2,2,<class 'warp.types.float32'> constructor$",
157
+ ):
158
+ wp.launch(
159
+ kernel,
160
+ dim=1,
161
+ inputs=[],
162
+ device=device,
163
+ )
198
164
 
199
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
200
- vec2 = wp.types.vector(length=2, dtype=wptype)
201
- vec3 = wp.types.vector(length=3, dtype=wptype)
202
- vec4 = wp.types.vector(length=4, dtype=wptype)
203
- vec5 = wp.types.vector(length=5, dtype=wptype)
204
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
205
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
206
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
207
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
208
165
 
209
- output_select_kernel = get_select_kernel(wptype)
166
+ def test_tpl_constructor_error_invalid_vector_count(test, device):
167
+ @wp.kernel
168
+ def kernel():
169
+ wp.mat22(wp.vec3(1.0, 2.0, 3.0))
210
170
 
211
- def check_scalar_mat_constructor(
212
- input: wp.array(dtype=wptype),
213
- outcomponents: wp.array(dtype=wptype),
171
+ with test.assertRaisesRegex(
172
+ RuntimeError,
173
+ r"Wrong number of vectors when attempting to construct a matrix " r"with column vectors$",
214
174
  ):
215
- # multiply outputs by 2 so we've got something to backpropagate:
216
- m2result = wptype(2) * mat22(input[0])
217
- m3result = wptype(2) * mat33(input[0])
218
- m4result = wptype(2) * mat44(input[0])
219
- m5result = wptype(2) * mat55(input[0])
220
-
221
- idx = 0
222
- for i in range(2):
223
- for j in range(2):
224
- outcomponents[idx] = m2result[i, j]
225
- idx = idx + 1
226
-
227
- for i in range(3):
228
- for j in range(3):
229
- outcomponents[idx] = m3result[i, j]
230
- idx = idx + 1
175
+ wp.launch(
176
+ kernel,
177
+ dim=1,
178
+ inputs=[],
179
+ device=device,
180
+ )
231
181
 
232
- for i in range(4):
233
- for j in range(4):
234
- outcomponents[idx] = m4result[i, j]
235
- idx = idx + 1
236
182
 
237
- for i in range(5):
238
- for j in range(5):
239
- outcomponents[idx] = m5result[i, j]
240
- idx = idx + 1
183
+ def test_tpl_constructor_error_invalid_vector_shape(test, device):
184
+ @wp.kernel
185
+ def kernel():
186
+ wp.mat22(wp.vec3(1.0, 2.0, 3.0), wp.vec3(4.0, 5.0, 6.0))
241
187
 
242
- def check_component_mat_constructor(
243
- input: wp.array(dtype=wptype),
244
- outcomponents: wp.array(dtype=wptype),
188
+ with test.assertRaisesRegex(
189
+ RuntimeError,
190
+ r"Wrong vector row count when attempting to construct a matrix " r"with column vectors$",
245
191
  ):
246
- # multiply outputs by 2 so we've got something to backpropagate:
247
- m2result = wptype(2) * mat22(input[0], input[1], input[2], input[3])
248
- m3result = wptype(2) * mat33(
249
- input[4],
250
- input[5],
251
- input[6],
252
- input[7],
253
- input[8],
254
- input[9],
255
- input[10],
256
- input[11],
257
- input[12],
258
- )
259
- m4result = wptype(2) * mat44(
260
- input[13],
261
- input[14],
262
- input[15],
263
- input[16],
264
- input[17],
265
- input[18],
266
- input[19],
267
- input[20],
268
- input[21],
269
- input[22],
270
- input[23],
271
- input[24],
272
- input[25],
273
- input[26],
274
- input[27],
275
- input[28],
276
- )
277
- m5result = wptype(2) * mat55(
278
- input[29],
279
- input[30],
280
- input[31],
281
- input[32],
282
- input[33],
283
- input[34],
284
- input[35],
285
- input[36],
286
- input[37],
287
- input[38],
288
- input[39],
289
- input[40],
290
- input[41],
291
- input[42],
292
- input[43],
293
- input[44],
294
- input[45],
295
- input[46],
296
- input[47],
297
- input[48],
298
- input[49],
299
- input[50],
300
- input[51],
301
- input[52],
302
- input[53],
192
+ wp.launch(
193
+ kernel,
194
+ dim=1,
195
+ inputs=[],
196
+ device=device,
303
197
  )
304
198
 
305
- idx = 0
306
- for i in range(2):
307
- for j in range(2):
308
- outcomponents[idx] = m2result[i, j]
309
- idx = idx + 1
310
-
311
- for i in range(3):
312
- for j in range(3):
313
- outcomponents[idx] = m3result[i, j]
314
- idx = idx + 1
315
-
316
- for i in range(4):
317
- for j in range(4):
318
- outcomponents[idx] = m4result[i, j]
319
- idx = idx + 1
320
199
 
321
- for i in range(5):
322
- for j in range(5):
323
- outcomponents[idx] = m5result[i, j]
324
- idx = idx + 1
200
+ def test_tpl_constructor_error_invalid_arg_count(test, device):
201
+ @wp.kernel
202
+ def kernel():
203
+ wp.mat22(1.0, 2.0, 3.0)
325
204
 
326
- def check_vector_mat_constructor(
327
- input: wp.array(dtype=wptype),
328
- outcomponents: wp.array(dtype=wptype),
205
+ with test.assertRaisesRegex(
206
+ RuntimeError,
207
+ r"Wrong number of scalars when attempting to construct a matrix " r"from a list of components$",
329
208
  ):
330
- # multiply outputs by 2 so we've got something to backpropagate:
331
- m2result = wptype(2) * mat22(vec2(input[0], input[2]), vec2(input[1], input[3]))
332
- m3result = wptype(2) * mat33(
333
- vec3(input[4], input[7], input[10]),
334
- vec3(input[5], input[8], input[11]),
335
- vec3(input[6], input[9], input[12]),
336
- )
337
- m4result = wptype(2) * mat44(
338
- vec4(input[13], input[17], input[21], input[25]),
339
- vec4(input[14], input[18], input[22], input[26]),
340
- vec4(input[15], input[19], input[23], input[27]),
341
- vec4(input[16], input[20], input[24], input[28]),
342
- )
343
- m5result = wptype(2) * mat55(
344
- vec5(input[29], input[34], input[39], input[44], input[49]),
345
- vec5(input[30], input[35], input[40], input[45], input[50]),
346
- vec5(input[31], input[36], input[41], input[46], input[51]),
347
- vec5(input[32], input[37], input[42], input[47], input[52]),
348
- vec5(input[33], input[38], input[43], input[48], input[53]),
209
+ wp.launch(
210
+ kernel,
211
+ dim=1,
212
+ inputs=[],
213
+ device=device,
349
214
  )
350
215
 
351
- idx = 0
352
- for i in range(2):
353
- for j in range(2):
354
- outcomponents[idx] = m2result[i, j]
355
- idx = idx + 1
356
-
357
- for i in range(3):
358
- for j in range(3):
359
- outcomponents[idx] = m3result[i, j]
360
- idx = idx + 1
361
-
362
- for i in range(4):
363
- for j in range(4):
364
- outcomponents[idx] = m4result[i, j]
365
- idx = idx + 1
366
216
 
367
- for i in range(5):
368
- for j in range(5):
369
- outcomponents[idx] = m5result[i, j]
370
- idx = idx + 1
217
+ def test_tpl_ops_with_anon(test, device):
218
+ mat22f = wp.mat((2, 2), dtype=float)
371
219
 
372
- kernel = getkernel(check_scalar_mat_constructor, suffix=dtype.__name__)
373
- compkernel = getkernel(check_component_mat_constructor, suffix=dtype.__name__)
374
- veckernel = getkernel(check_vector_mat_constructor, suffix=dtype.__name__)
220
+ m = wp.mat22f(1.0, 2.0, 3.0, 4.0)
221
+ m += mat22f(2.0, 3.0, 4.0, 5.0)
222
+ m -= mat22f(3.0, 4.0, 5.0, 6.0)
223
+ test.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
375
224
 
376
- if register_kernels:
377
- return
225
+ m = mat22f(1.0, 2.0, 3.0, 4.0)
226
+ m += wp.mat22f(2.0, 3.0, 4.0, 5.0)
227
+ m -= wp.mat22f(3.0, 4.0, 5.0, 6.0)
228
+ test.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
378
229
 
379
- input = wp.array(randvals([1], dtype), requires_grad=True, device=device)
380
- val = input.numpy()[0]
381
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
382
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
383
230
 
384
- wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
231
+ def test_py_arithmetic_ops(test, device, dtype):
232
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
385
233
 
386
- assert_np_equal(outcomponents.numpy()[:4], 2 * val * np.ones(2 * 2), tol=tol)
387
- assert_np_equal(outcomponents.numpy()[4:13], 2 * val * np.ones(3 * 3), tol=tol)
388
- assert_np_equal(outcomponents.numpy()[13:29], 2 * val * np.ones(4 * 4), tol=tol)
389
- assert_np_equal(outcomponents.numpy()[29:54], 2 * val * np.ones(5 * 5), tol=tol)
234
+ def make_mat(*args):
235
+ if wptype in wp.types.int_types:
236
+ # Cast to the correct integer type to simulate wrapping.
237
+ return tuple(tuple(wptype._type_(x).value for x in row) for row in args)
390
238
 
391
- if dtype in np_float_types:
392
- for idx in range(len(outcomponents)):
393
- tape = wp.Tape()
394
- with tape:
395
- wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
396
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
397
- tape.backward(loss=out)
398
- test.assertEqual(tape.gradients[input].numpy()[0], 2)
399
- tape.zero()
239
+ return args
400
240
 
401
- input = wp.array(randvals([2 * 2 + 3 * 3 + 4 * 4 + 5 * 5], dtype), requires_grad=True, device=device)
241
+ def make_vec(*args):
242
+ if wptype in wp.types.int_types:
243
+ # Cast to the correct integer type to simulate wrapping.
244
+ return tuple(wptype._type_(x).value for x in args)
402
245
 
403
- wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
404
- assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
246
+ return args
405
247
 
406
- if dtype in np_float_types:
407
- for idx in range(len(outcomponents)):
408
- tape = wp.Tape()
409
- with tape:
410
- wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
411
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
412
- tape.backward(loss=out)
413
- expectedgrads = np.zeros(len(input))
414
- expectedgrads[idx] = 2
415
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
416
- tape.zero()
248
+ mat_cls = wp.mat((3, 3), wptype)
249
+ vec_cls = wp.vec(3, wptype)
417
250
 
418
- wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
419
- assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
251
+ m = mat_cls(((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
252
+ test.assertSequenceEqual(+m, make_mat((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
253
+ test.assertSequenceEqual(-m, make_mat((1, -2, -3), (-4, 5, -6), (-7, -8, 9)))
254
+ test.assertSequenceEqual(m + mat_cls((5, 5, 5) * 3), make_mat((4, 7, 8), (9, 0, 11), (12, 13, -4)))
255
+ test.assertSequenceEqual(m - mat_cls((5, 5, 5) * 3), make_mat((-6, -3, -2), (-1, -10, 1), (2, 3, -14)))
256
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(20, 25, 30))
257
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(20, 25, 30))
258
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(50, 25, 0))
259
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(50, 25, 0))
420
260
 
421
- if dtype in np_float_types:
422
- for idx in range(len(outcomponents)):
423
- tape = wp.Tape()
424
- with tape:
425
- wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
426
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
427
- tape.backward(loss=out)
428
- expectedgrads = np.zeros(len(input))
429
- expectedgrads[idx] = 2
430
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
431
- tape.zero()
261
+ m = mat_cls(((2, 4, 6), (8, 10, 12), (14, 16, 18)))
262
+ test.assertSequenceEqual(m * wptype(2), make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
263
+ test.assertSequenceEqual(wptype(2) * m, make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
264
+ test.assertSequenceEqual(m / wptype(2), make_mat((1, 2, 3), (4, 5, 6), (7, 8, 9)))
265
+ test.assertSequenceEqual(wptype(5040) / m, make_mat((2520, 1260, 840), (630, 504, 420), (360, 315, 280)))
266
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(60, 150, 240))
267
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(60, 150, 240))
268
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(120, 150, 180))
269
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(120, 150, 180))
432
270
 
433
271
 
434
272
  def test_quat_constructor(test, device, dtype, register_kernels=False):
435
- np.random.seed(123)
273
+ rng = np.random.default_rng(123)
436
274
 
437
275
  tol = {
438
276
  np.float16: 1.0e-3,
@@ -481,15 +319,15 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
481
319
  return
482
320
 
483
321
  # translation:
484
- p = wp.array(np.random.randn(1, 3).astype(dtype), dtype=vec3, requires_grad=True, device=device)
322
+ p = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
485
323
 
486
324
  # generate a normalized quaternion for the rotation:
487
- r = np.random.randn(1, 4)
325
+ r = rng.standard_normal(size=(1, 4))
488
326
  r /= np.linalg.norm(r)
489
327
  r = wp.array(r.astype(dtype), dtype=quat, requires_grad=True, device=device)
490
328
 
491
329
  # scale:
492
- s = wp.array(np.random.randn(1, 3).astype(dtype), dtype=vec3, requires_grad=True, device=device)
330
+ s = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
493
331
 
494
332
  # just going to generate the matrix using the constructor, then
495
333
  # more manually, and make sure the values/gradients are the same:
@@ -530,11 +368,11 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
530
368
  idx = idx + 1
531
369
 
532
370
 
533
- def test_indexing(test, device, dtype, register_kernels=False):
534
- np.random.seed(123)
371
+ def test_negation(test, device, dtype, register_kernels=False):
372
+ rng = np.random.default_rng(123)
535
373
 
536
374
  tol = {
537
- np.float16: 1.0e-3,
375
+ np.float16: 1.0e-2,
538
376
  np.float32: 1.0e-6,
539
377
  np.float64: 1.0e-8,
540
378
  }.get(dtype, 0)
@@ -547,52 +385,57 @@ def test_indexing(test, device, dtype, register_kernels=False):
547
385
 
548
386
  output_select_kernel = get_select_kernel(wptype)
549
387
 
550
- def check_mat_indexing(
388
+ def check_mat_negation(
551
389
  m2: wp.array(dtype=mat22),
552
390
  m3: wp.array(dtype=mat33),
553
391
  m4: wp.array(dtype=mat44),
554
392
  m5: wp.array(dtype=mat55),
555
393
  outcomponents: wp.array(dtype=wptype),
556
394
  ):
395
+ mat2 = -m2[0]
396
+ mat3 = -m3[0]
397
+ mat4 = -m4[0]
398
+ mat5 = -m5[0]
399
+
557
400
  # multiply outputs by 2 so we've got something to backpropagate:
558
401
  idx = 0
559
402
  for i in range(2):
560
403
  for j in range(2):
561
- outcomponents[idx] = wptype(2) * m2[0][i, j]
404
+ outcomponents[idx] = wptype(2) * mat2[i, j]
562
405
  idx = idx + 1
563
406
 
564
407
  for i in range(3):
565
408
  for j in range(3):
566
- outcomponents[idx] = wptype(2) * m3[0][i, j]
409
+ outcomponents[idx] = wptype(2) * mat3[i, j]
567
410
  idx = idx + 1
568
411
 
569
412
  for i in range(4):
570
413
  for j in range(4):
571
- outcomponents[idx] = wptype(2) * m4[0][i, j]
414
+ outcomponents[idx] = wptype(2) * mat4[i, j]
572
415
  idx = idx + 1
573
416
 
574
417
  for i in range(5):
575
418
  for j in range(5):
576
- outcomponents[idx] = wptype(2) * m5[0][i, j]
419
+ outcomponents[idx] = wptype(2) * mat5[i, j]
577
420
  idx = idx + 1
578
421
 
579
- kernel = getkernel(check_mat_indexing, suffix=dtype.__name__)
422
+ kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
580
423
 
581
424
  if register_kernels:
582
425
  return
583
426
 
584
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
585
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
586
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
587
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
427
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
428
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
429
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
430
+ m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
588
431
  outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
589
432
 
590
433
  wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
591
434
 
592
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy().reshape(-1), tol=tol)
593
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy().reshape(-1), tol=tol)
594
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy().reshape(-1), tol=tol)
595
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy().reshape(-1), tol=tol)
435
+ assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
436
+ assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
437
+ assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
438
+ assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
596
439
 
597
440
  if dtype in np_float_types:
598
441
  idx = 0
@@ -608,291 +451,17 @@ def test_indexing(test, device, dtype, register_kernels=False):
608
451
  )
609
452
  tape.backward(loss=out)
610
453
  expectedresult = np.zeros((dim, dim), dtype=dtype)
611
- expectedresult[i, j] = 2
454
+ expectedresult[i, j] = -2
612
455
  assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
613
456
  tape.zero()
614
457
  idx = idx + 1
615
458
 
616
459
 
617
- def test_equality(test, device, dtype, register_kernels=False):
618
- np.random.seed(123)
619
-
620
- tol = {
621
- np.float16: 1.0e-3,
622
- np.float32: 1.0e-6,
623
- np.float64: 1.0e-8,
624
- }.get(dtype, 0)
625
-
626
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
627
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
628
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
629
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
630
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
631
-
632
- def check_mat_equality():
633
- wp.expect_eq(
634
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
635
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
636
- )
637
- wp.expect_neq(
638
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), -wptype(4.0)),
639
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
640
- )
641
-
642
- wp.expect_eq(
643
- mat33(
644
- wptype(1.0),
645
- wptype(2.0),
646
- wptype(3.0),
647
- wptype(4.0),
648
- wptype(5.0),
649
- wptype(6.0),
650
- wptype(7.0),
651
- wptype(8.0),
652
- wptype(9.0),
653
- ),
654
- mat33(
655
- wptype(1.0),
656
- wptype(2.0),
657
- wptype(3.0),
658
- wptype(4.0),
659
- wptype(5.0),
660
- wptype(6.0),
661
- wptype(7.0),
662
- wptype(8.0),
663
- wptype(9.0),
664
- ),
665
- )
666
- wp.expect_neq(
667
- mat33(
668
- wptype(1.0),
669
- wptype(2.0),
670
- wptype(3.0),
671
- wptype(4.0),
672
- wptype(5.0),
673
- wptype(6.0),
674
- wptype(7.0),
675
- wptype(8.0),
676
- wptype(9.0),
677
- ),
678
- mat33(
679
- wptype(1.0),
680
- wptype(2.0),
681
- wptype(3.0),
682
- -wptype(4.0),
683
- wptype(5.0),
684
- wptype(6.0),
685
- wptype(7.0),
686
- wptype(8.0),
687
- wptype(9.0),
688
- ),
689
- )
690
-
691
- wp.expect_eq(
692
- mat44(
693
- wptype(1.0),
694
- wptype(2.0),
695
- wptype(3.0),
696
- wptype(4.0),
697
- wptype(5.0),
698
- wptype(6.0),
699
- wptype(7.0),
700
- wptype(8.0),
701
- wptype(9.0),
702
- wptype(10.0),
703
- wptype(11.0),
704
- wptype(12.0),
705
- wptype(13.0),
706
- wptype(14.0),
707
- wptype(15.0),
708
- wptype(16.0),
709
- ),
710
- mat44(
711
- wptype(1.0),
712
- wptype(2.0),
713
- wptype(3.0),
714
- wptype(4.0),
715
- wptype(5.0),
716
- wptype(6.0),
717
- wptype(7.0),
718
- wptype(8.0),
719
- wptype(9.0),
720
- wptype(10.0),
721
- wptype(11.0),
722
- wptype(12.0),
723
- wptype(13.0),
724
- wptype(14.0),
725
- wptype(15.0),
726
- wptype(16.0),
727
- ),
728
- )
729
-
730
- wp.expect_neq(
731
- mat44(
732
- wptype(1.0),
733
- wptype(2.0),
734
- wptype(3.0),
735
- wptype(4.0),
736
- wptype(5.0),
737
- wptype(6.0),
738
- wptype(7.0),
739
- wptype(8.0),
740
- wptype(9.0),
741
- wptype(10.0),
742
- wptype(11.0),
743
- wptype(12.0),
744
- wptype(13.0),
745
- wptype(14.0),
746
- wptype(15.0),
747
- wptype(16.0),
748
- ),
749
- mat44(
750
- -wptype(1.0),
751
- wptype(2.0),
752
- wptype(3.0),
753
- wptype(4.0),
754
- wptype(5.0),
755
- wptype(6.0),
756
- wptype(7.0),
757
- wptype(8.0),
758
- wptype(9.0),
759
- wptype(10.0),
760
- wptype(11.0),
761
- wptype(12.0),
762
- wptype(13.0),
763
- wptype(14.0),
764
- wptype(15.0),
765
- wptype(16.0),
766
- ),
767
- )
768
-
769
- wp.expect_eq(
770
- mat55(
771
- wptype(1.0),
772
- wptype(2.0),
773
- wptype(3.0),
774
- wptype(4.0),
775
- wptype(5.0),
776
- wptype(6.0),
777
- wptype(7.0),
778
- wptype(8.0),
779
- wptype(9.0),
780
- wptype(10.0),
781
- wptype(11.0),
782
- wptype(12.0),
783
- wptype(13.0),
784
- wptype(14.0),
785
- wptype(15.0),
786
- wptype(16.0),
787
- wptype(17.0),
788
- wptype(18.0),
789
- wptype(19.0),
790
- wptype(20.0),
791
- wptype(21.0),
792
- wptype(22.0),
793
- wptype(23.0),
794
- wptype(24.0),
795
- wptype(25.0),
796
- ),
797
- mat55(
798
- wptype(1.0),
799
- wptype(2.0),
800
- wptype(3.0),
801
- wptype(4.0),
802
- wptype(5.0),
803
- wptype(6.0),
804
- wptype(7.0),
805
- wptype(8.0),
806
- wptype(9.0),
807
- wptype(10.0),
808
- wptype(11.0),
809
- wptype(12.0),
810
- wptype(13.0),
811
- wptype(14.0),
812
- wptype(15.0),
813
- wptype(16.0),
814
- wptype(17.0),
815
- wptype(18.0),
816
- wptype(19.0),
817
- wptype(20.0),
818
- wptype(21.0),
819
- wptype(22.0),
820
- wptype(23.0),
821
- wptype(24.0),
822
- wptype(25.0),
823
- ),
824
- )
825
-
826
- wp.expect_neq(
827
- mat55(
828
- wptype(1.0),
829
- wptype(2.0),
830
- wptype(3.0),
831
- wptype(4.0),
832
- wptype(5.0),
833
- wptype(6.0),
834
- wptype(7.0),
835
- wptype(8.0),
836
- wptype(9.0),
837
- wptype(10.0),
838
- wptype(11.0),
839
- wptype(12.0),
840
- wptype(13.0),
841
- wptype(14.0),
842
- wptype(15.0),
843
- wptype(16.0),
844
- wptype(17.0),
845
- wptype(18.0),
846
- wptype(19.0),
847
- wptype(20.0),
848
- wptype(21.0),
849
- wptype(22.0),
850
- wptype(23.0),
851
- wptype(24.0),
852
- wptype(25.0),
853
- ),
854
- mat55(
855
- wptype(1.0),
856
- wptype(2.0),
857
- wptype(3.0),
858
- wptype(4.0),
859
- wptype(5.0),
860
- wptype(6.0),
861
- wptype(7.0),
862
- wptype(8.0),
863
- wptype(9.0),
864
- wptype(10.0),
865
- wptype(11.0),
866
- wptype(12.0),
867
- wptype(13.0),
868
- wptype(14.0),
869
- wptype(15.0),
870
- wptype(16.0),
871
- -wptype(17.0),
872
- wptype(18.0),
873
- wptype(19.0),
874
- wptype(20.0),
875
- wptype(21.0),
876
- wptype(22.0),
877
- wptype(23.0),
878
- wptype(24.0),
879
- wptype(25.0),
880
- ),
881
- )
882
-
883
- kernel = getkernel(check_mat_equality, suffix=dtype.__name__)
884
-
885
- if register_kernels:
886
- return
887
-
888
- wp.launch(kernel, dim=1, inputs=[], outputs=[], device=device)
889
-
890
-
891
- def test_negation(test, device, dtype, register_kernels=False):
892
- np.random.seed(123)
460
+ def test_subtraction(test, device, dtype, register_kernels=False):
461
+ rng = np.random.default_rng(123)
893
462
 
894
463
  tol = {
895
- np.float16: 1.0e-2,
464
+ np.float16: 5.0e-3,
896
465
  np.float32: 1.0e-6,
897
466
  np.float64: 1.0e-8,
898
467
  }.get(dtype, 0)
@@ -905,1401 +474,57 @@ def test_negation(test, device, dtype, register_kernels=False):
905
474
 
906
475
  output_select_kernel = get_select_kernel(wptype)
907
476
 
908
- def check_mat_negation(
909
- m2: wp.array(dtype=mat22),
910
- m3: wp.array(dtype=mat33),
911
- m4: wp.array(dtype=mat44),
912
- m5: wp.array(dtype=mat55),
477
+ def check_mat_sub(
478
+ s2: wp.array(dtype=mat22),
479
+ s3: wp.array(dtype=mat33),
480
+ s4: wp.array(dtype=mat44),
481
+ s5: wp.array(dtype=mat55),
482
+ v2: wp.array(dtype=mat22),
483
+ v3: wp.array(dtype=mat33),
484
+ v4: wp.array(dtype=mat44),
485
+ v5: wp.array(dtype=mat55),
913
486
  outcomponents: wp.array(dtype=wptype),
914
487
  ):
915
- mat2 = -m2[0]
916
- mat3 = -m3[0]
917
- mat4 = -m4[0]
918
- mat5 = -m5[0]
488
+ v2result = v2[0] - s2[0]
489
+ v3result = v3[0] - s3[0]
490
+ v4result = v4[0] - s4[0]
491
+ v5result = v5[0] - s5[0]
919
492
 
920
493
  # multiply outputs by 2 so we've got something to backpropagate:
921
494
  idx = 0
922
495
  for i in range(2):
923
496
  for j in range(2):
924
- outcomponents[idx] = wptype(2) * mat2[i, j]
497
+ outcomponents[idx] = wptype(2) * v2result[i, j]
925
498
  idx = idx + 1
926
499
 
927
500
  for i in range(3):
928
501
  for j in range(3):
929
- outcomponents[idx] = wptype(2) * mat3[i, j]
502
+ outcomponents[idx] = wptype(2) * v3result[i, j]
930
503
  idx = idx + 1
931
504
 
932
505
  for i in range(4):
933
506
  for j in range(4):
934
- outcomponents[idx] = wptype(2) * mat4[i, j]
507
+ outcomponents[idx] = wptype(2) * v4result[i, j]
935
508
  idx = idx + 1
936
509
 
937
510
  for i in range(5):
938
511
  for j in range(5):
939
- outcomponents[idx] = wptype(2) * mat5[i, j]
512
+ outcomponents[idx] = wptype(2) * v5result[i, j]
940
513
  idx = idx + 1
941
514
 
942
- kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
515
+ kernel = getkernel(check_mat_sub, suffix=dtype.__name__)
943
516
 
944
517
  if register_kernels:
945
518
  return
946
519
 
947
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
948
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
949
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
950
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
951
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
952
-
953
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
954
-
955
- assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
956
- assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
957
- assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
958
- assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
959
-
960
- if dtype in np_float_types:
961
- idx = 0
962
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
963
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
964
- for i in range(dim):
965
- for j in range(dim):
966
- tape = wp.Tape()
967
- with tape:
968
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
969
- wp.launch(
970
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
971
- )
972
- tape.backward(loss=out)
973
- expectedresult = np.zeros((dim, dim), dtype=dtype)
974
- expectedresult[i, j] = -2
975
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
976
- tape.zero()
977
- idx = idx + 1
978
-
979
-
980
- def test_transpose(test, device, dtype, register_kernels=False):
981
- np.random.seed(123)
982
-
983
- tol = {
984
- np.float16: 1.0e-2,
985
- np.float32: 1.0e-6,
986
- np.float64: 1.0e-8,
987
- }.get(dtype, 0)
988
-
989
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
990
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
991
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
992
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
993
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
994
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
995
-
996
- output_select_kernel = get_select_kernel(wptype)
997
-
998
- def check_mat_transpose(
999
- m2: wp.array(dtype=mat22),
1000
- m3: wp.array(dtype=mat33),
1001
- m4: wp.array(dtype=mat44),
1002
- m5: wp.array(dtype=mat55),
1003
- m32: wp.array(dtype=mat32),
1004
- outcomponents: wp.array(dtype=wptype),
1005
- ):
1006
- # multiply outputs by 2 so we've got something to backpropagate:
1007
- mat2 = wptype(2) * wp.transpose(m2[0])
1008
- mat3 = wptype(2) * wp.transpose(m3[0])
1009
- mat4 = wptype(2) * wp.transpose(m4[0])
1010
- mat5 = wptype(2) * wp.transpose(m5[0])
1011
- mat32 = wptype(2) * wp.transpose(m32[0])
1012
-
1013
- idx = 0
1014
- for i in range(2):
1015
- for j in range(2):
1016
- outcomponents[idx] = mat2[i, j]
1017
- idx = idx + 1
1018
-
1019
- for i in range(3):
1020
- for j in range(3):
1021
- outcomponents[idx] = mat3[i, j]
1022
- idx = idx + 1
1023
-
1024
- for i in range(4):
1025
- for j in range(4):
1026
- outcomponents[idx] = mat4[i, j]
1027
- idx = idx + 1
1028
-
1029
- for i in range(5):
1030
- for j in range(5):
1031
- outcomponents[idx] = mat5[i, j]
1032
- idx = idx + 1
1033
-
1034
- for i in range(2):
1035
- for j in range(3):
1036
- outcomponents[idx] = mat32[i, j]
1037
- idx = idx + 1
1038
-
1039
- kernel = getkernel(check_mat_transpose, suffix=dtype.__name__)
1040
-
1041
- if register_kernels:
1042
- return
1043
-
1044
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1045
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1046
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1047
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1048
- m32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1049
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 3, dtype=wptype, requires_grad=True, device=device)
1050
-
1051
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1052
-
1053
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy()[0].T.reshape(-1), tol=tol)
1054
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy()[0].T.reshape(-1), tol=tol)
1055
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy()[0].T.reshape(-1), tol=tol)
1056
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy()[0].T.reshape(-1), tol=tol)
1057
- assert_np_equal(outcomponents.numpy()[54:], 2 * m32.numpy()[0].T.reshape(-1), tol=tol)
1058
-
1059
- if dtype in np_float_types:
1060
- idx = 0
1061
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1062
- for input in [m2, m3, m4, m5]:
1063
- for i in range(input.dtype._shape_[0]):
1064
- for j in range(input.dtype._shape_[1]):
1065
- tape = wp.Tape()
1066
- with tape:
1067
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1068
- wp.launch(
1069
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1070
- )
1071
- tape.backward(loss=out)
1072
- expectedresult = np.zeros((input.dtype._shape_[1], input.dtype._shape_[0]), dtype=dtype)
1073
- expectedresult[j, i] = 2
1074
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
1075
- tape.zero()
1076
- idx = idx + 1
1077
-
1078
-
1079
- def test_scalar_multiplication(test, device, dtype, register_kernels=False):
1080
- np.random.seed(123)
1081
-
1082
- tol = {
1083
- np.float16: 1.0e-2,
1084
- np.float32: 1.0e-6,
1085
- np.float64: 1.0e-8,
1086
- }.get(dtype, 0)
1087
-
1088
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1089
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1090
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1091
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1092
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1093
-
1094
- output_select_kernel = get_select_kernel(wptype)
1095
-
1096
- def check_mat_scalar_mul(
1097
- s: wp.array(dtype=wptype),
1098
- m2: wp.array(dtype=mat22),
1099
- m3: wp.array(dtype=mat33),
1100
- m4: wp.array(dtype=mat44),
1101
- m5: wp.array(dtype=mat55),
1102
- outcomponents: wp.array(dtype=wptype),
1103
- outcomponents_rightmul: wp.array(dtype=wptype),
1104
- ):
1105
- m2result = s[0] * m2[0]
1106
- m3result = s[0] * m3[0]
1107
- m4result = s[0] * m4[0]
1108
- m5result = s[0] * m5[0]
1109
-
1110
- m2resultright = m2[0] * s[0]
1111
- m3resultright = m3[0] * s[0]
1112
- m4resultright = m4[0] * s[0]
1113
- m5resultright = m5[0] * s[0]
1114
-
1115
- m2result_2 = s[0] * m2[0]
1116
- m3result_2 = s[0] * m3[0]
1117
- m4result_2 = s[0] * m4[0]
1118
- m5result_2 = s[0] * m5[0]
1119
-
1120
- m2resultright_2 = m2[0] * s[0]
1121
- m3resultright_2 = m3[0] * s[0]
1122
- m4resultright_2 = m4[0] * s[0]
1123
- m5resultright_2 = m5[0] * s[0]
1124
-
1125
- # multiply outputs by 2 so we've got something to backpropagate:
1126
- idx = 0
1127
- for i in range(2):
1128
- for j in range(2):
1129
- outcomponents[idx] = wptype(2) * m2result[i, j]
1130
- outcomponents_rightmul[idx] = wptype(2) * m2resultright[i, j]
1131
- idx = idx + 1
1132
-
1133
- for i in range(3):
1134
- for j in range(3):
1135
- outcomponents[idx] = wptype(2) * m3result[i, j]
1136
- outcomponents_rightmul[idx] = wptype(2) * m3resultright[i, j]
1137
- idx = idx + 1
1138
-
1139
- for i in range(4):
1140
- for j in range(4):
1141
- outcomponents[idx] = wptype(2) * m4result[i, j]
1142
- outcomponents_rightmul[idx] = wptype(2) * m4resultright[i, j]
1143
- idx = idx + 1
1144
-
1145
- for i in range(5):
1146
- for j in range(5):
1147
- outcomponents[idx] = wptype(2) * m5result[i, j]
1148
- outcomponents_rightmul[idx] = wptype(2) * m5resultright[i, j]
1149
- idx = idx + 1
1150
-
1151
- for i in range(2):
1152
- for j in range(2):
1153
- outcomponents[idx] = wptype(2) * m2result_2[i, j]
1154
- outcomponents_rightmul[idx] = wptype(2) * m2resultright_2[i, j]
1155
- idx = idx + 1
1156
-
1157
- for i in range(3):
1158
- for j in range(3):
1159
- outcomponents[idx] = wptype(2) * m3result_2[i, j]
1160
- outcomponents_rightmul[idx] = wptype(2) * m3resultright_2[i, j]
1161
- idx = idx + 1
1162
-
1163
- for i in range(4):
1164
- for j in range(4):
1165
- outcomponents[idx] = wptype(2) * m4result_2[i, j]
1166
- outcomponents_rightmul[idx] = wptype(2) * m4resultright_2[i, j]
1167
- idx = idx + 1
1168
-
1169
- for i in range(5):
1170
- for j in range(5):
1171
- outcomponents[idx] = wptype(2) * m5result_2[i, j]
1172
- outcomponents_rightmul[idx] = wptype(2) * m5resultright_2[i, j]
1173
- idx = idx + 1
1174
-
1175
- kernel = getkernel(check_mat_scalar_mul, suffix=dtype.__name__)
1176
-
1177
- if register_kernels:
1178
- return
1179
-
1180
- s = wp.array(randvals([1], dtype), requires_grad=True, device=device)
1181
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1182
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1183
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1184
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1185
- outcomponents = wp.zeros(2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5), dtype=wptype, requires_grad=True, device=device)
1186
- outcomponents_rightmul = wp.zeros(
1187
- 2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5), dtype=wptype, requires_grad=True, device=device
1188
- )
1189
-
1190
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents, outcomponents_rightmul], device=device)
1191
-
1192
- sval = s.numpy()[0]
1193
- assert_np_equal(outcomponents.numpy()[:4], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1194
- assert_np_equal(outcomponents.numpy()[4:13], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1195
- assert_np_equal(outcomponents.numpy()[13:29], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1196
- assert_np_equal(outcomponents.numpy()[29:54], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1197
-
1198
- assert_np_equal(outcomponents_rightmul.numpy()[:4], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1199
- assert_np_equal(outcomponents_rightmul.numpy()[4:13], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1200
- assert_np_equal(outcomponents_rightmul.numpy()[13:29], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1201
- assert_np_equal(outcomponents_rightmul.numpy()[29:54], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1202
-
1203
- assert_np_equal(outcomponents.numpy()[54:58], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1204
- assert_np_equal(outcomponents.numpy()[58:67], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1205
- assert_np_equal(outcomponents.numpy()[67:83], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1206
- assert_np_equal(outcomponents.numpy()[83:108], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1207
-
1208
- assert_np_equal(outcomponents_rightmul.numpy()[54:58], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1209
- assert_np_equal(outcomponents_rightmul.numpy()[58:67], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1210
- assert_np_equal(outcomponents_rightmul.numpy()[67:83], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1211
- assert_np_equal(outcomponents_rightmul.numpy()[83:108], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1212
-
1213
- if dtype in np_float_types:
1214
- idx = 0
1215
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1216
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
1217
- for i in range(dim):
1218
- for j in range(dim):
1219
- # test left mul gradient:
1220
- tape = wp.Tape()
1221
- with tape:
1222
- wp.launch(
1223
- kernel,
1224
- dim=1,
1225
- inputs=[s, m2, m3, m4, m5],
1226
- outputs=[outcomponents, outcomponents_rightmul],
1227
- device=device,
1228
- )
1229
- wp.launch(
1230
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1231
- )
1232
- tape.backward(loss=out)
1233
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1234
- expectedresult[i, j] = 2 * sval
1235
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
1236
- assert_np_equal(tape.gradients[s].numpy()[0], 2 * input.numpy()[0, i, j], tol=10 * tol)
1237
- tape.zero()
1238
-
1239
- # test right mul gradient:
1240
- tape = wp.Tape()
1241
- with tape:
1242
- wp.launch(
1243
- kernel,
1244
- dim=1,
1245
- inputs=[s, m2, m3, m4, m5],
1246
- outputs=[outcomponents, outcomponents_rightmul],
1247
- device=device,
1248
- )
1249
- wp.launch(
1250
- output_select_kernel,
1251
- dim=1,
1252
- inputs=[outcomponents_rightmul, idx],
1253
- outputs=[out],
1254
- device=device,
1255
- )
1256
- tape.backward(loss=out)
1257
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1258
- expectedresult[i, j] = 2 * sval
1259
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
1260
- assert_np_equal(tape.gradients[s].numpy()[0], 2 * input.numpy()[0, i, j], tol=10 * tol)
1261
- tape.zero()
1262
-
1263
- idx = idx + 1
1264
-
1265
-
1266
- def test_matvec_multiplication(test, device, dtype, register_kernels=False):
1267
- np.random.seed(123)
1268
-
1269
- tol = {
1270
- np.float16: 2.0e-2,
1271
- np.float32: 5.0e-6,
1272
- np.float64: 1.0e-8,
1273
- }.get(dtype, 0)
1274
-
1275
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1276
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1277
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1278
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1279
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1280
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1281
-
1282
- vec2 = wp.types.vector(length=2, dtype=wptype)
1283
- vec3 = wp.types.vector(length=3, dtype=wptype)
1284
- vec4 = wp.types.vector(length=4, dtype=wptype)
1285
- vec5 = wp.types.vector(length=5, dtype=wptype)
1286
-
1287
- output_select_kernel = get_select_kernel(wptype)
1288
-
1289
- def check_mat_vec_mul(
1290
- v2: wp.array(dtype=vec2),
1291
- v3: wp.array(dtype=vec3),
1292
- v4: wp.array(dtype=vec4),
1293
- v5: wp.array(dtype=vec5),
1294
- v32: wp.array(dtype=vec2),
1295
- m2: wp.array(dtype=mat22),
1296
- m3: wp.array(dtype=mat33),
1297
- m4: wp.array(dtype=mat44),
1298
- m5: wp.array(dtype=mat55),
1299
- m32: wp.array(dtype=mat32),
1300
- outcomponents: wp.array(dtype=wptype),
1301
- ):
1302
- v2result = m2[0] * v2[0]
1303
- v3result = m3[0] * v3[0]
1304
- v4result = m4[0] * v4[0]
1305
- v5result = m5[0] * v5[0]
1306
- v32result = m32[0] * v32[0]
1307
- v2result_2 = m2[0] @ v2[0]
1308
- v3result_2 = m3[0] @ v3[0]
1309
- v4result_2 = m4[0] @ v4[0]
1310
- v5result_2 = m5[0] @ v5[0]
1311
- v32result_2 = m32[0] @ v32[0]
1312
-
1313
- idx = 0
1314
-
1315
- # multiply outputs by 2 so we've got something to backpropagate:
1316
- for i in range(2):
1317
- outcomponents[idx] = wptype(2) * v2result[i]
1318
- idx = idx + 1
1319
-
1320
- for i in range(3):
1321
- outcomponents[idx] = wptype(2) * v3result[i]
1322
- idx = idx + 1
1323
-
1324
- for i in range(4):
1325
- outcomponents[idx] = wptype(2) * v4result[i]
1326
- idx = idx + 1
1327
-
1328
- for i in range(5):
1329
- outcomponents[idx] = wptype(2) * v5result[i]
1330
- idx = idx + 1
1331
-
1332
- for i in range(3):
1333
- outcomponents[idx] = wptype(2) * v32result[i]
1334
- idx = idx + 1
1335
-
1336
- for i in range(2):
1337
- outcomponents[idx] = wptype(2) * v2result_2[i]
1338
- idx = idx + 1
1339
-
1340
- for i in range(3):
1341
- outcomponents[idx] = wptype(2) * v3result_2[i]
1342
- idx = idx + 1
1343
-
1344
- for i in range(4):
1345
- outcomponents[idx] = wptype(2) * v4result_2[i]
1346
- idx = idx + 1
1347
-
1348
- for i in range(5):
1349
- outcomponents[idx] = wptype(2) * v5result_2[i]
1350
- idx = idx + 1
1351
-
1352
- for i in range(3):
1353
- outcomponents[idx] = wptype(2) * v32result_2[i]
1354
- idx = idx + 1
1355
-
1356
- kernel = getkernel(check_mat_vec_mul, suffix=dtype.__name__)
1357
-
1358
- if register_kernels:
1359
- return
1360
-
1361
- v2 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1362
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1363
- v4 = wp.array(randvals([1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1364
- v5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1365
- v32 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1366
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1367
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1368
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1369
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1370
- m32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1371
- outcomponents = wp.zeros(2 * (2 + 3 + 4 + 5 + 3), dtype=wptype, requires_grad=True, device=device)
1372
-
1373
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1374
-
1375
- assert_np_equal(outcomponents.numpy()[:2], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1376
- assert_np_equal(outcomponents.numpy()[2:5], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1377
- assert_np_equal(outcomponents.numpy()[5:9], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=5 * tol)
1378
- assert_np_equal(outcomponents.numpy()[9:14], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=5 * tol)
1379
- assert_np_equal(outcomponents.numpy()[14:17], 2 * np.matmul(m32.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1380
- assert_np_equal(outcomponents.numpy()[17:19], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1381
- assert_np_equal(outcomponents.numpy()[19:22], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1382
- assert_np_equal(outcomponents.numpy()[22:26], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=5 * tol)
1383
- assert_np_equal(outcomponents.numpy()[26:31], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=5 * tol)
1384
- assert_np_equal(outcomponents.numpy()[31:34], 2 * np.matmul(m32.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1385
-
1386
- if dtype in np_float_types:
1387
- idx = 0
1388
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1389
- for dim, invec, inmat in [(2, v2, m2), (3, v3, m3), (4, v4, m4), (5, v5, m5), (3, v32, m32)]:
1390
- for i in range(dim):
1391
- tape = wp.Tape()
1392
- with tape:
1393
- wp.launch(
1394
- kernel,
1395
- dim=1,
1396
- inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
1397
- outputs=[outcomponents],
1398
- device=device,
1399
- )
1400
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1401
- tape.backward(loss=out)
1402
-
1403
- assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, i, :], tol=2 * tol)
1404
- expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
1405
- expectedresult[i, :] = 2 * invec.numpy()[0]
1406
- assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)
1407
-
1408
- tape.zero()
1409
-
1410
- idx = idx + 1
1411
-
1412
-
1413
- def test_matmat_multiplication(test, device, dtype, register_kernels=False):
1414
- np.random.seed(123)
1415
-
1416
- tol = {
1417
- np.float16: 2.0e-2,
1418
- np.float32: 5.0e-6,
1419
- np.float64: 1.0e-8,
1420
- }.get(dtype, 0)
1421
-
1422
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1423
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1424
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1425
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1426
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1427
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1428
-
1429
- output_select_kernel = get_select_kernel(wptype)
1430
-
1431
- def check_mat_mat_mul(
1432
- a2: wp.array(dtype=mat22),
1433
- a3: wp.array(dtype=mat33),
1434
- a4: wp.array(dtype=mat44),
1435
- a5: wp.array(dtype=mat55),
1436
- a32: wp.array(dtype=mat32),
1437
- b2: wp.array(dtype=mat22),
1438
- b3: wp.array(dtype=mat33),
1439
- b4: wp.array(dtype=mat44),
1440
- b5: wp.array(dtype=mat55),
1441
- b32: wp.array(dtype=mat32),
1442
- outcomponents: wp.array(dtype=wptype),
1443
- ):
1444
- c2result = b2[0] * a2[0]
1445
- c3result = b3[0] * a3[0]
1446
- c4result = b4[0] * a4[0]
1447
- c5result = b5[0] * a5[0]
1448
- c32result = b32[0] * a2[0]
1449
- c32result2 = b3[0] * a32[0]
1450
- c2result_2 = b2[0] @ a2[0]
1451
- c3result_2 = b3[0] @ a3[0]
1452
- c4result_2 = b4[0] @ a4[0]
1453
- c5result_2 = b5[0] @ a5[0]
1454
- c32result_2 = b32[0] @ a2[0]
1455
- c32result2_2 = b3[0] @ a32[0]
1456
-
1457
- # multiply outputs by 2 so we've got something to backpropagate:
1458
- idx = 0
1459
- for i in range(2):
1460
- for j in range(2):
1461
- outcomponents[idx] = wptype(2) * c2result[i, j]
1462
- idx = idx + 1
1463
-
1464
- for i in range(3):
1465
- for j in range(3):
1466
- outcomponents[idx] = wptype(2) * c3result[i, j]
1467
- idx = idx + 1
1468
-
1469
- for i in range(4):
1470
- for j in range(4):
1471
- outcomponents[idx] = wptype(2) * c4result[i, j]
1472
- idx = idx + 1
1473
-
1474
- for i in range(5):
1475
- for j in range(5):
1476
- outcomponents[idx] = wptype(2) * c5result[i, j]
1477
- idx = idx + 1
1478
-
1479
- for i in range(3):
1480
- for j in range(2):
1481
- outcomponents[idx] = wptype(2) * c32result[i, j]
1482
- idx = idx + 1
1483
-
1484
- for i in range(3):
1485
- for j in range(2):
1486
- outcomponents[idx] = wptype(2) * c32result2[i, j]
1487
- idx = idx + 1
1488
-
1489
- for i in range(2):
1490
- for j in range(2):
1491
- outcomponents[idx] = wptype(2) * c2result_2[i, j]
1492
- idx = idx + 1
1493
-
1494
- for i in range(3):
1495
- for j in range(3):
1496
- outcomponents[idx] = wptype(2) * c3result_2[i, j]
1497
- idx = idx + 1
1498
-
1499
- for i in range(4):
1500
- for j in range(4):
1501
- outcomponents[idx] = wptype(2) * c4result_2[i, j]
1502
- idx = idx + 1
1503
-
1504
- for i in range(5):
1505
- for j in range(5):
1506
- outcomponents[idx] = wptype(2) * c5result_2[i, j]
1507
- idx = idx + 1
1508
-
1509
- for i in range(3):
1510
- for j in range(2):
1511
- outcomponents[idx] = wptype(2) * c32result_2[i, j]
1512
- idx = idx + 1
1513
-
1514
- for i in range(3):
1515
- for j in range(2):
1516
- outcomponents[idx] = wptype(2) * c32result2_2[i, j]
1517
- idx = idx + 1
1518
-
1519
- kernel = getkernel(check_mat_mat_mul, suffix=dtype.__name__)
1520
-
1521
- if register_kernels:
1522
- return
1523
-
1524
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1525
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1526
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1527
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1528
- v32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1529
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1530
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1531
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1532
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1533
- m32 = wp.array(randvals([1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1534
- outcomponents = wp.zeros(
1535
- 2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2 + 3 * 2), dtype=wptype, requires_grad=True, device=device
1536
- )
1537
-
1538
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1539
-
1540
- assert_np_equal(outcomponents.numpy()[:4], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1541
- assert_np_equal(outcomponents.numpy()[4:13], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1542
- assert_np_equal(outcomponents.numpy()[13:29], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=2 * tol)
1543
- assert_np_equal(outcomponents.numpy()[29:54], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=10 * tol)
1544
- assert_np_equal(outcomponents.numpy()[54:60], 2 * np.matmul(m32.numpy()[0], v2.numpy()[0]), tol=5 * tol)
1545
- assert_np_equal(outcomponents.numpy()[60:66], 2 * np.matmul(m3.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1546
- assert_np_equal(outcomponents.numpy()[66:70], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1547
- assert_np_equal(outcomponents.numpy()[70:79], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1548
- assert_np_equal(outcomponents.numpy()[79:95], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=2 * tol)
1549
- assert_np_equal(outcomponents.numpy()[95:120], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=10 * tol)
1550
- assert_np_equal(outcomponents.numpy()[120:126], 2 * np.matmul(m32.numpy()[0], v2.numpy()[0]), tol=5 * tol)
1551
- assert_np_equal(outcomponents.numpy()[126:132], 2 * np.matmul(m3.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1552
-
1553
- if dtype in np_float_types:
1554
- idx = 0
1555
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1556
- for v, m in [(v2, m2), (v3, m3), (v4, m4), (v5, m5), (v2, m32), (v32, m3)]:
1557
- rows, cols = m.dtype._shape_[0], v.dtype._shape_[1]
1558
- for i in range(rows):
1559
- for j in range(cols):
1560
- tape = wp.Tape()
1561
- with tape:
1562
- wp.launch(
1563
- kernel,
1564
- dim=1,
1565
- inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
1566
- outputs=[outcomponents],
1567
- device=device,
1568
- )
1569
- wp.launch(
1570
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1571
- )
1572
- tape.backward(loss=out)
1573
-
1574
- expected = np.zeros(v.dtype._shape_, dtype=dtype)
1575
- expected[:, j] = 2 * m.numpy()[0, i, :]
1576
- assert_np_equal(tape.gradients[v].numpy()[0], expected, tol=10 * tol)
1577
-
1578
- expected = np.zeros(m.dtype._shape_, dtype=dtype)
1579
- expected[i, :] = 2 * v.numpy()[0, :, j]
1580
- assert_np_equal(tape.gradients[m].numpy()[0], expected, tol=10 * tol)
1581
-
1582
- tape.zero()
1583
- idx = idx + 1
1584
-
1585
-
1586
- def test_cw_multiplication(test, device, dtype, register_kernels=False):
1587
- np.random.seed(123)
1588
-
1589
- tol = {
1590
- np.float16: 5.0e-2,
1591
- np.float32: 1.0e-6,
1592
- np.float64: 1.0e-8,
1593
- }.get(dtype, 0)
1594
-
1595
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1596
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1597
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1598
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1599
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1600
-
1601
- output_select_kernel = get_select_kernel(wptype)
1602
-
1603
- def check_mat_cw_mul(
1604
- s2: wp.array(dtype=mat22),
1605
- s3: wp.array(dtype=mat33),
1606
- s4: wp.array(dtype=mat44),
1607
- s5: wp.array(dtype=mat55),
1608
- v2: wp.array(dtype=mat22),
1609
- v3: wp.array(dtype=mat33),
1610
- v4: wp.array(dtype=mat44),
1611
- v5: wp.array(dtype=mat55),
1612
- outcomponents: wp.array(dtype=wptype),
1613
- ):
1614
- v2result = wptype(2) * wp.cw_mul(v2[0], s2[0])
1615
- v3result = wptype(2) * wp.cw_mul(v3[0], s3[0])
1616
- v4result = wptype(2) * wp.cw_mul(v4[0], s4[0])
1617
- v5result = wptype(2) * wp.cw_mul(v5[0], s5[0])
1618
-
1619
- # multiply outputs by 2 so we've got something to backpropagate:
1620
- idx = 0
1621
- for i in range(2):
1622
- for j in range(2):
1623
- outcomponents[idx] = v2result[i, j]
1624
- idx = idx + 1
1625
-
1626
- for i in range(3):
1627
- for j in range(3):
1628
- outcomponents[idx] = v3result[i, j]
1629
- idx = idx + 1
1630
-
1631
- for i in range(4):
1632
- for j in range(4):
1633
- outcomponents[idx] = v4result[i, j]
1634
- idx = idx + 1
1635
-
1636
- for i in range(5):
1637
- for j in range(5):
1638
- outcomponents[idx] = v5result[i, j]
1639
- idx = idx + 1
1640
-
1641
- kernel = getkernel(check_mat_cw_mul, suffix=dtype.__name__)
1642
-
1643
- if register_kernels:
1644
- return
1645
-
1646
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1647
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1648
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1649
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1650
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1651
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1652
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1653
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1654
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1655
-
1656
- wp.launch(
1657
- kernel,
1658
- dim=1,
1659
- inputs=[
1660
- s2,
1661
- s3,
1662
- s4,
1663
- s5,
1664
- v2,
1665
- v3,
1666
- v4,
1667
- v5,
1668
- ],
1669
- outputs=[outcomponents],
1670
- device=device,
1671
- )
1672
-
1673
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() * s2.numpy()).reshape(-1), tol=50 * tol)
1674
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() * s3.numpy()).reshape(-1), tol=50 * tol)
1675
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() * s4.numpy()).reshape(-1), tol=50 * tol)
1676
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() * s5.numpy()).reshape(-1), tol=50 * tol)
1677
-
1678
- if dtype in np_float_types:
1679
- idx = 0
1680
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1681
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
1682
- for i in range(dim):
1683
- for j in range(dim):
1684
- tape = wp.Tape()
1685
- with tape:
1686
- wp.launch(
1687
- kernel,
1688
- dim=1,
1689
- inputs=[
1690
- s2,
1691
- s3,
1692
- s4,
1693
- s5,
1694
- v2,
1695
- v3,
1696
- v4,
1697
- v5,
1698
- ],
1699
- outputs=[outcomponents],
1700
- device=device,
1701
- )
1702
- wp.launch(
1703
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1704
- )
1705
- tape.backward(loss=out)
1706
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1707
- expectedresult[i, j] = 2 * in1.numpy()[0][i, j]
1708
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=5 * tol)
1709
- expectedresult[i, j] = 2 * in2.numpy()[0][i, j]
1710
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=5 * tol)
1711
- tape.zero()
1712
-
1713
- idx = idx + 1
1714
-
1715
-
1716
- def test_cw_division(test, device, dtype, register_kernels=False):
1717
- np.random.seed(123)
1718
-
1719
- tol = {
1720
- np.float16: 1.0e-2,
1721
- np.float32: 1.0e-6,
1722
- np.float64: 1.0e-8,
1723
- }.get(dtype, 0)
1724
-
1725
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1726
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1727
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1728
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1729
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1730
-
1731
- output_select_kernel = get_select_kernel(wptype)
1732
-
1733
- def check_mat_cw_div(
1734
- s2: wp.array(dtype=mat22),
1735
- s3: wp.array(dtype=mat33),
1736
- s4: wp.array(dtype=mat44),
1737
- s5: wp.array(dtype=mat55),
1738
- v2: wp.array(dtype=mat22),
1739
- v3: wp.array(dtype=mat33),
1740
- v4: wp.array(dtype=mat44),
1741
- v5: wp.array(dtype=mat55),
1742
- outcomponents: wp.array(dtype=wptype),
1743
- ):
1744
- v2result = wptype(2) * wp.cw_div(v2[0], s2[0])
1745
- v3result = wptype(2) * wp.cw_div(v3[0], s3[0])
1746
- v4result = wptype(2) * wp.cw_div(v4[0], s4[0])
1747
- v5result = wptype(2) * wp.cw_div(v5[0], s5[0])
1748
-
1749
- # multiply outputs by 2 so we've got something to backpropagate:
1750
- idx = 0
1751
- for i in range(2):
1752
- for j in range(2):
1753
- outcomponents[idx] = v2result[i, j]
1754
- idx = idx + 1
1755
-
1756
- for i in range(3):
1757
- for j in range(3):
1758
- outcomponents[idx] = v3result[i, j]
1759
- idx = idx + 1
1760
-
1761
- for i in range(4):
1762
- for j in range(4):
1763
- outcomponents[idx] = v4result[i, j]
1764
- idx = idx + 1
1765
-
1766
- for i in range(5):
1767
- for j in range(5):
1768
- outcomponents[idx] = v5result[i, j]
1769
- idx = idx + 1
1770
-
1771
- kernel = getkernel(check_mat_cw_div, suffix=dtype.__name__)
1772
-
1773
- if register_kernels:
1774
- return
1775
-
1776
- s2 = randvals([1, 2, 2], dtype)
1777
- s3 = randvals([1, 3, 3], dtype)
1778
- s4 = randvals([1, 4, 4], dtype)
1779
- s5 = randvals([1, 5, 5], dtype)
1780
-
1781
- # set denominators to 1 if their magnitudes are small
1782
- # to prevent divide by zero, or overflows if we're testing
1783
- # float16:
1784
- s2[np.abs(s2) < 1.0e-2] = 1
1785
- s3[np.abs(s3) < 1.0e-2] = 1
1786
- s4[np.abs(s4) < 1.0e-2] = 1
1787
- s5[np.abs(s5) < 1.0e-2] = 1
1788
-
1789
- s2 = wp.array(s2, dtype=mat22, requires_grad=True, device=device)
1790
- s3 = wp.array(s3, dtype=mat33, requires_grad=True, device=device)
1791
- s4 = wp.array(s4, dtype=mat44, requires_grad=True, device=device)
1792
- s5 = wp.array(s5, dtype=mat55, requires_grad=True, device=device)
1793
-
1794
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1795
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1796
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1797
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1798
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1799
-
1800
- wp.launch(
1801
- kernel,
1802
- dim=1,
1803
- inputs=[
1804
- s2,
1805
- s3,
1806
- s4,
1807
- s5,
1808
- v2,
1809
- v3,
1810
- v4,
1811
- v5,
1812
- ],
1813
- outputs=[outcomponents],
1814
- device=device,
1815
- )
1816
-
1817
- if dtype in np_float_types:
1818
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() / s2.numpy()).reshape(-1), tol=50 * tol)
1819
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() / s3.numpy()).reshape(-1), tol=50 * tol)
1820
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() / s4.numpy()).reshape(-1), tol=50 * tol)
1821
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() / s5.numpy()).reshape(-1), tol=50 * tol)
1822
- else:
1823
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() // s2.numpy()).reshape(-1), tol=50 * tol)
1824
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() // s3.numpy()).reshape(-1), tol=50 * tol)
1825
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() // s4.numpy()).reshape(-1), tol=50 * tol)
1826
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() // s5.numpy()).reshape(-1), tol=50 * tol)
1827
-
1828
- if dtype in np_float_types:
1829
- idx = 0
1830
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1831
- for dim, s, v in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
1832
- for i in range(dim):
1833
- for j in range(dim):
1834
- tape = wp.Tape()
1835
- with tape:
1836
- wp.launch(
1837
- kernel,
1838
- dim=1,
1839
- inputs=[
1840
- s2,
1841
- s3,
1842
- s4,
1843
- s5,
1844
- v2,
1845
- v3,
1846
- v4,
1847
- v5,
1848
- ],
1849
- outputs=[outcomponents],
1850
- device=device,
1851
- )
1852
- wp.launch(
1853
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1854
- )
1855
- tape.backward(loss=out)
1856
-
1857
- # y = v/s
1858
- # dy/dv = 1.0/s
1859
- # dy/ds = -v/s^2
1860
-
1861
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1862
- expectedresult[i, j] = 2.0 / (s.numpy()[0, i, j])
1863
- assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=50 * tol)
1864
- expectedresult[i, j] = -2.0 * v.numpy()[0, i, j] / (s.numpy()[0, i, j] ** 2)
1865
- assert_np_equal(
1866
- tape.gradients[s].numpy()[0], expectedresult, tol=abs(outcomponents.numpy()[idx]) * 50 * tol
1867
- )
1868
- tape.zero()
1869
-
1870
- idx = idx + 1
1871
-
1872
-
1873
- def test_outer_product(test, device, dtype, register_kernels=False):
1874
- np.random.seed(123)
1875
-
1876
- tol = {
1877
- np.float16: 5.0e-3,
1878
- np.float32: 1.0e-6,
1879
- np.float64: 1.0e-8,
1880
- }.get(dtype, 0)
1881
-
1882
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1883
- vec2 = wp.types.vector(length=2, dtype=wptype)
1884
- vec3 = wp.types.vector(length=3, dtype=wptype)
1885
- vec4 = wp.types.vector(length=4, dtype=wptype)
1886
- vec5 = wp.types.vector(length=5, dtype=wptype)
1887
-
1888
- output_select_kernel = get_select_kernel(wptype)
1889
-
1890
- def check_mat_outer_product(
1891
- s2: wp.array(dtype=vec2),
1892
- s3: wp.array(dtype=vec3),
1893
- s4: wp.array(dtype=vec4),
1894
- s5: wp.array(dtype=vec5),
1895
- v2: wp.array(dtype=vec2),
1896
- v3: wp.array(dtype=vec3),
1897
- v4: wp.array(dtype=vec4),
1898
- v5: wp.array(dtype=vec5),
1899
- outcomponents: wp.array(dtype=wptype),
1900
- ):
1901
- m22result = wptype(2) * wp.outer(s2[0], v2[0])
1902
- m33result = wptype(2) * wp.outer(s3[0], v3[0])
1903
- m44result = wptype(2) * wp.outer(s4[0], v4[0])
1904
- m55result = wptype(2) * wp.outer(s5[0], v5[0])
1905
- m25result = wptype(2) * wp.outer(s2[0], v5[0])
1906
-
1907
- # multiply outputs by 2 so we've got something to backpropagate:
1908
- idx = 0
1909
- for i in range(2):
1910
- for j in range(2):
1911
- outcomponents[idx] = m22result[i, j]
1912
- idx = idx + 1
1913
-
1914
- for i in range(3):
1915
- for j in range(3):
1916
- outcomponents[idx] = m33result[i, j]
1917
- idx = idx + 1
1918
-
1919
- for i in range(4):
1920
- for j in range(4):
1921
- outcomponents[idx] = m44result[i, j]
1922
- idx = idx + 1
1923
-
1924
- for i in range(5):
1925
- for j in range(5):
1926
- outcomponents[idx] = m55result[i, j]
1927
- idx = idx + 1
1928
-
1929
- for i in range(2):
1930
- for j in range(5):
1931
- outcomponents[idx] = m25result[i, j]
1932
- idx = idx + 1
1933
-
1934
- kernel = getkernel(check_mat_outer_product, suffix=dtype.__name__)
1935
-
1936
- if register_kernels:
1937
- return
1938
-
1939
- s2 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1940
- s3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1941
- s4 = wp.array(randvals([1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1942
- s5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1943
- v2 = wp.array(randvals([1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1944
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1945
- v4 = wp.array(randvals([1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1946
- v5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1947
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 5, dtype=wptype, requires_grad=True, device=device)
1948
-
1949
- wp.launch(kernel, dim=1, inputs=[s2, s3, s4, s5, v2, v3, v4, v5], outputs=[outcomponents], device=device)
1950
-
1951
- assert_np_equal(outcomponents.numpy()[:4], 2 * s2.numpy()[0, :, None] * v2.numpy()[0, None, :], tol=tol)
1952
- assert_np_equal(outcomponents.numpy()[4:13], 2 * s3.numpy()[0, :, None] * v3.numpy()[0, None, :], tol=10 * tol)
1953
- assert_np_equal(outcomponents.numpy()[13:29], 2 * s4.numpy()[0, :, None] * v4.numpy()[0, None, :], tol=10 * tol)
1954
- assert_np_equal(outcomponents.numpy()[29:54], 2 * s5.numpy()[0, :, None] * v5.numpy()[0, None, :], tol=10 * tol)
1955
- assert_np_equal(outcomponents.numpy()[54:], 2 * s2.numpy()[0, :, None] * v5.numpy()[0, None, :], tol=10 * tol)
1956
-
1957
- if dtype in np_float_types:
1958
- idx = 0
1959
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1960
- for s, v in [(s2, v2), (s3, v3), (s4, v4), (s5, v5), (s2, v5)]:
1961
- rows = s.dtype._length_
1962
- cols = v.dtype._length_
1963
- for i in range(rows):
1964
- for j in range(cols):
1965
- tape = wp.Tape()
1966
- with tape:
1967
- wp.launch(
1968
- kernel,
1969
- dim=1,
1970
- inputs=[
1971
- s2,
1972
- s3,
1973
- s4,
1974
- s5,
1975
- v2,
1976
- v3,
1977
- v4,
1978
- v5,
1979
- ],
1980
- outputs=[outcomponents],
1981
- device=device,
1982
- )
1983
- wp.launch(
1984
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1985
- )
1986
- tape.backward(loss=out)
1987
-
1988
- # this component's gonna be s_i * v_j, so its s gradient is gonna be nozero
1989
- # at the ith component and its v gradient will be nonzero at the jth component:
1990
-
1991
- expectedresult = np.zeros((rows), dtype=dtype)
1992
- expectedresult[i] = 2 * v.numpy()[0, j]
1993
- assert_np_equal(tape.gradients[s].numpy()[0], expectedresult, tol=10 * tol)
1994
-
1995
- expectedresult = np.zeros((cols), dtype=dtype)
1996
- expectedresult[j] = 2 * s.numpy()[0, i]
1997
- assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=10 * tol)
1998
- tape.zero()
1999
-
2000
- idx = idx + 1
2001
-
2002
-
2003
- def test_scalar_division(test, device, dtype, register_kernels=False):
2004
- np.random.seed(123)
2005
-
2006
- tol = {
2007
- np.float16: 1.0e-2,
2008
- np.float32: 1.0e-6,
2009
- np.float64: 1.0e-8,
2010
- }.get(dtype, 0)
2011
-
2012
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2013
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2014
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2015
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2016
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2017
-
2018
- output_select_kernel = get_select_kernel(wptype)
2019
-
2020
- def check_mat_scalar_div(
2021
- s: wp.array(dtype=wptype),
2022
- m2: wp.array(dtype=mat22),
2023
- m3: wp.array(dtype=mat33),
2024
- m4: wp.array(dtype=mat44),
2025
- m5: wp.array(dtype=mat55),
2026
- outcomponents: wp.array(dtype=wptype),
2027
- ):
2028
- m2result = m2[0] / s[0]
2029
- m3result = m3[0] / s[0]
2030
- m4result = m4[0] / s[0]
2031
- m5result = m5[0] / s[0]
2032
-
2033
- # multiply outputs by 2 so we've got something to backpropagate:
2034
- idx = 0
2035
- for i in range(2):
2036
- for j in range(2):
2037
- outcomponents[idx] = wptype(2) * m2result[i, j]
2038
- idx = idx + 1
2039
-
2040
- for i in range(3):
2041
- for j in range(3):
2042
- outcomponents[idx] = wptype(2) * m3result[i, j]
2043
- idx = idx + 1
2044
-
2045
- for i in range(4):
2046
- for j in range(4):
2047
- outcomponents[idx] = wptype(2) * m4result[i, j]
2048
- idx = idx + 1
2049
-
2050
- for i in range(5):
2051
- for j in range(5):
2052
- outcomponents[idx] = wptype(2) * m5result[i, j]
2053
- idx = idx + 1
2054
-
2055
- kernel = getkernel(check_mat_scalar_div, suffix=dtype.__name__)
2056
-
2057
- if register_kernels:
2058
- return
2059
-
2060
- s = wp.array(randvals([1], dtype), requires_grad=True, device=device)
2061
- m2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2062
- m3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2063
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2064
- m5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2065
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2066
-
2067
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
2068
-
2069
- sval = s.numpy()[0]
2070
- if dtype in np_float_types:
2071
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy().reshape(-1) / sval, tol=tol)
2072
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy().reshape(-1) / sval, tol=10 * tol)
2073
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy().reshape(-1) / sval, tol=10 * tol)
2074
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy().reshape(-1) / sval, tol=10 * tol)
2075
- else:
2076
- assert_np_equal(outcomponents.numpy()[:4], 2 * (m2.numpy().reshape(-1) // sval), tol=tol)
2077
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (m3.numpy().reshape(-1) // sval), tol=10 * tol)
2078
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (m4.numpy().reshape(-1) // sval), tol=10 * tol)
2079
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (m5.numpy().reshape(-1) // sval), tol=10 * tol)
2080
-
2081
- if dtype in np_float_types:
2082
- idx = 0
2083
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2084
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
2085
- for i in range(dim):
2086
- for j in range(dim):
2087
- tape = wp.Tape()
2088
- with tape:
2089
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
2090
- wp.launch(
2091
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2092
- )
2093
- tape.backward(loss=out)
2094
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2095
- expectedresult[i, j] = 2.0 / sval
2096
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
2097
- assert_np_equal(
2098
- tape.gradients[s].numpy()[0], -2 * input.numpy()[0, i, j] / (sval * sval), tol=10 * tol
2099
- )
2100
- tape.zero()
2101
-
2102
- idx = idx + 1
2103
-
2104
-
2105
- def test_addition(test, device, dtype, register_kernels=False):
2106
- np.random.seed(123)
2107
-
2108
- tol = {
2109
- np.float16: 2.0e-2,
2110
- np.float32: 5.0e-6,
2111
- np.float64: 1.0e-8,
2112
- }.get(dtype, 0)
2113
-
2114
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2115
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2116
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2117
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2118
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2119
-
2120
- output_select_kernel = get_select_kernel(wptype)
2121
-
2122
- def check_mat_add(
2123
- s2: wp.array(dtype=mat22),
2124
- s3: wp.array(dtype=mat33),
2125
- s4: wp.array(dtype=mat44),
2126
- s5: wp.array(dtype=mat55),
2127
- v2: wp.array(dtype=mat22),
2128
- v3: wp.array(dtype=mat33),
2129
- v4: wp.array(dtype=mat44),
2130
- v5: wp.array(dtype=mat55),
2131
- outcomponents: wp.array(dtype=wptype),
2132
- ):
2133
- v2result = v2[0] + s2[0]
2134
- v3result = v3[0] + s3[0]
2135
- v4result = v4[0] + s4[0]
2136
- v5result = v5[0] + s5[0]
2137
-
2138
- # multiply outputs by 2 so we've got something to backpropagate:
2139
- idx = 0
2140
- for i in range(2):
2141
- for j in range(2):
2142
- outcomponents[idx] = wptype(2) * v2result[i, j]
2143
- idx = idx + 1
2144
-
2145
- for i in range(3):
2146
- for j in range(3):
2147
- outcomponents[idx] = wptype(2) * v3result[i, j]
2148
- idx = idx + 1
2149
-
2150
- for i in range(4):
2151
- for j in range(4):
2152
- outcomponents[idx] = wptype(2) * v4result[i, j]
2153
- idx = idx + 1
2154
-
2155
- for i in range(5):
2156
- for j in range(5):
2157
- outcomponents[idx] = wptype(2) * v5result[i, j]
2158
- idx = idx + 1
2159
-
2160
- kernel = getkernel(check_mat_add, suffix=dtype.__name__)
2161
-
2162
- if register_kernels:
2163
- return
2164
-
2165
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2166
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2167
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2168
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2169
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2170
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2171
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2172
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2173
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2174
-
2175
- wp.launch(
2176
- kernel,
2177
- dim=1,
2178
- inputs=[
2179
- s2,
2180
- s3,
2181
- s4,
2182
- s5,
2183
- v2,
2184
- v3,
2185
- v4,
2186
- v5,
2187
- ],
2188
- outputs=[outcomponents],
2189
- device=device,
2190
- )
2191
-
2192
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() + s2.numpy()).reshape(-1), tol=tol)
2193
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() + s3.numpy()).reshape(-1), tol=tol)
2194
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() + s4.numpy()).reshape(-1), tol=tol)
2195
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() + s5.numpy()).reshape(-1), tol=tol)
2196
-
2197
- if dtype in np_float_types:
2198
- idx = 0
2199
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2200
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
2201
- for i in range(dim):
2202
- for j in range(dim):
2203
- tape = wp.Tape()
2204
- with tape:
2205
- wp.launch(
2206
- kernel,
2207
- dim=1,
2208
- inputs=[
2209
- s2,
2210
- s3,
2211
- s4,
2212
- s5,
2213
- v2,
2214
- v3,
2215
- v4,
2216
- v5,
2217
- ],
2218
- outputs=[outcomponents],
2219
- device=device,
2220
- )
2221
- wp.launch(
2222
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2223
- )
2224
- tape.backward(loss=out)
2225
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2226
- expectedresult[i, j] = 2
2227
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=10 * tol)
2228
- expectedresult[i, j] = 2
2229
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=10 * tol)
2230
- tape.zero()
2231
-
2232
- idx = idx + 1
2233
-
2234
-
2235
- def test_subtraction(test, device, dtype, register_kernels=False):
2236
- np.random.seed(123)
2237
-
2238
- tol = {
2239
- np.float16: 5.0e-3,
2240
- np.float32: 1.0e-6,
2241
- np.float64: 1.0e-8,
2242
- }.get(dtype, 0)
2243
-
2244
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2245
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2246
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2247
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2248
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2249
-
2250
- output_select_kernel = get_select_kernel(wptype)
2251
-
2252
- def check_mat_sub(
2253
- s2: wp.array(dtype=mat22),
2254
- s3: wp.array(dtype=mat33),
2255
- s4: wp.array(dtype=mat44),
2256
- s5: wp.array(dtype=mat55),
2257
- v2: wp.array(dtype=mat22),
2258
- v3: wp.array(dtype=mat33),
2259
- v4: wp.array(dtype=mat44),
2260
- v5: wp.array(dtype=mat55),
2261
- outcomponents: wp.array(dtype=wptype),
2262
- ):
2263
- v2result = v2[0] - s2[0]
2264
- v3result = v3[0] - s3[0]
2265
- v4result = v4[0] - s4[0]
2266
- v5result = v5[0] - s5[0]
2267
-
2268
- # multiply outputs by 2 so we've got something to backpropagate:
2269
- idx = 0
2270
- for i in range(2):
2271
- for j in range(2):
2272
- outcomponents[idx] = wptype(2) * v2result[i, j]
2273
- idx = idx + 1
2274
-
2275
- for i in range(3):
2276
- for j in range(3):
2277
- outcomponents[idx] = wptype(2) * v3result[i, j]
2278
- idx = idx + 1
2279
-
2280
- for i in range(4):
2281
- for j in range(4):
2282
- outcomponents[idx] = wptype(2) * v4result[i, j]
2283
- idx = idx + 1
2284
-
2285
- for i in range(5):
2286
- for j in range(5):
2287
- outcomponents[idx] = wptype(2) * v5result[i, j]
2288
- idx = idx + 1
2289
-
2290
- kernel = getkernel(check_mat_sub, suffix=dtype.__name__)
2291
-
2292
- if register_kernels:
2293
- return
2294
-
2295
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2296
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2297
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2298
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2299
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2300
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2301
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2302
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
520
+ s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
521
+ s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
522
+ s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
523
+ s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
524
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
525
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
526
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
527
+ v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2303
528
  outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2304
529
 
2305
530
  wp.launch(
@@ -2362,131 +587,8 @@ def test_subtraction(test, device, dtype, register_kernels=False):
2362
587
  idx = idx + 1
2363
588
 
2364
589
 
2365
- def test_ddot(test, device, dtype, register_kernels=False):
2366
- np.random.seed(123)
2367
-
2368
- tol = {
2369
- np.float16: 5.0e-3,
2370
- np.float32: 1.0e-6,
2371
- np.float64: 1.0e-8,
2372
- }.get(dtype, 0)
2373
-
2374
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2375
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2376
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2377
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2378
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2379
-
2380
- def check_mat_dot(
2381
- s2: wp.array(dtype=mat22),
2382
- s3: wp.array(dtype=mat33),
2383
- s4: wp.array(dtype=mat44),
2384
- s5: wp.array(dtype=mat55),
2385
- v2: wp.array(dtype=mat22),
2386
- v3: wp.array(dtype=mat33),
2387
- v4: wp.array(dtype=mat44),
2388
- v5: wp.array(dtype=mat55),
2389
- dot2: wp.array(dtype=wptype),
2390
- dot3: wp.array(dtype=wptype),
2391
- dot4: wp.array(dtype=wptype),
2392
- dot5: wp.array(dtype=wptype),
2393
- ):
2394
- # multiply outputs by 2 so we've got something to backpropagate:
2395
- dot2[0] = wptype(2) * wp.ddot(v2[0], s2[0])
2396
- dot3[0] = wptype(2) * wp.ddot(v3[0], s3[0])
2397
- dot4[0] = wptype(2) * wp.ddot(v4[0], s4[0])
2398
- dot5[0] = wptype(2) * wp.ddot(v5[0], s5[0])
2399
-
2400
- kernel = getkernel(check_mat_dot, suffix=dtype.__name__)
2401
-
2402
- if register_kernels:
2403
- return
2404
-
2405
- s2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2406
- s3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2407
- s4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2408
- s5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2409
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2410
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2411
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2412
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2413
- dot2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2414
- dot3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2415
- dot4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2416
- dot5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2417
-
2418
- tape = wp.Tape()
2419
- with tape:
2420
- wp.launch(
2421
- kernel,
2422
- dim=1,
2423
- inputs=[
2424
- s2,
2425
- s3,
2426
- s4,
2427
- s5,
2428
- v2,
2429
- v3,
2430
- v4,
2431
- v5,
2432
- ],
2433
- outputs=[dot2, dot3, dot4, dot5],
2434
- device=device,
2435
- )
2436
-
2437
- assert_np_equal(dot2.numpy()[0], 2 * (v2.numpy() * s2.numpy()).sum(), tol=10 * tol)
2438
- assert_np_equal(dot3.numpy()[0], 2 * (v3.numpy() * s3.numpy()).sum(), tol=10 * tol)
2439
- assert_np_equal(dot4.numpy()[0], 2 * (v4.numpy() * s4.numpy()).sum(), tol=50 * tol)
2440
- assert_np_equal(dot5.numpy()[0], 2 * (v5.numpy() * s5.numpy()).sum(), tol=200 * tol)
2441
-
2442
- if dtype in np_float_types:
2443
- tape.backward(loss=dot2)
2444
- sgrads = tape.gradients[s2].numpy()[0]
2445
- expected_grads = 2.0 * v2.numpy()[0]
2446
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2447
-
2448
- vgrads = tape.gradients[v2].numpy()[0]
2449
- expected_grads = 2.0 * s2.numpy()[0]
2450
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2451
-
2452
- tape.zero()
2453
-
2454
- tape.backward(loss=dot3)
2455
- sgrads = tape.gradients[s3].numpy()[0]
2456
- expected_grads = 2.0 * v3.numpy()[0]
2457
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2458
-
2459
- vgrads = tape.gradients[v3].numpy()[0]
2460
- expected_grads = 2.0 * s3.numpy()[0]
2461
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2462
-
2463
- tape.zero()
2464
-
2465
- tape.backward(loss=dot4)
2466
- sgrads = tape.gradients[s4].numpy()[0]
2467
- expected_grads = 2.0 * v4.numpy()[0]
2468
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2469
-
2470
- vgrads = tape.gradients[v4].numpy()[0]
2471
- expected_grads = 2.0 * s4.numpy()[0]
2472
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2473
-
2474
- tape.zero()
2475
-
2476
- tape.backward(loss=dot5)
2477
- sgrads = tape.gradients[s5].numpy()[0]
2478
- expected_grads = 2.0 * v5.numpy()[0]
2479
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2480
-
2481
- vgrads = tape.gradients[v5].numpy()[0]
2482
- expected_grads = 2.0 * s5.numpy()[0]
2483
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2484
-
2485
- tape.zero()
2486
-
2487
-
2488
590
  def test_determinant(test, device, dtype, register_kernels=False):
2489
- np.random.seed(123)
591
+ rng = np.random.default_rng(123)
2490
592
 
2491
593
  tol = {
2492
594
  np.float16: 5.0e-3,
@@ -2516,9 +618,9 @@ def test_determinant(test, device, dtype, register_kernels=False):
2516
618
  if register_kernels:
2517
619
  return
2518
620
 
2519
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2520
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2521
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
621
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
622
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
623
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2522
624
  det2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2523
625
  det3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2524
626
  det4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -2637,266 +739,115 @@ def test_determinant(test, device, dtype, register_kernels=False):
2637
739
  v4,
2638
740
  ],
2639
741
  outputs=[
2640
- det2,
2641
- det3,
2642
- det4,
2643
- ],
2644
- device=device,
2645
- )
2646
- dminus = det3.numpy()[0]
2647
- assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
2648
-
2649
- for i in range(4):
2650
- for j in range(4):
2651
- v4test = v4.numpy()
2652
- v4test[0, i, j] += dx
2653
- wp.launch(
2654
- kernel,
2655
- dim=1,
2656
- inputs=[
2657
- v2,
2658
- v3,
2659
- wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
2660
- ],
2661
- outputs=[
2662
- det2,
2663
- det3,
2664
- det4,
2665
- ],
2666
- device=device,
2667
- )
2668
- dplus = det4.numpy()[0]
2669
- v4test[0, i, j] -= 2.0 * dx
2670
- wp.launch(
2671
- kernel,
2672
- dim=1,
2673
- inputs=[
2674
- v2,
2675
- v3,
2676
- wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
2677
- ],
2678
- outputs=[
2679
- det2,
2680
- det3,
2681
- det4,
2682
- ],
2683
- device=device,
2684
- )
2685
- dminus = det4.numpy()[0]
2686
- assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
2687
-
2688
-
2689
- def test_trace(test, device, dtype, register_kernels=False):
2690
- np.random.seed(123)
2691
-
2692
- tol = {
2693
- np.float16: 1.0e-3,
2694
- np.float32: 1.0e-6,
2695
- np.float64: 1.0e-8,
2696
- }.get(dtype, 0)
2697
-
2698
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2699
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2700
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2701
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2702
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2703
-
2704
- def check_mat_trace(
2705
- v2: wp.array(dtype=mat22),
2706
- v3: wp.array(dtype=mat33),
2707
- v4: wp.array(dtype=mat44),
2708
- v5: wp.array(dtype=mat55),
2709
- tr2: wp.array(dtype=wptype),
2710
- tr3: wp.array(dtype=wptype),
2711
- tr4: wp.array(dtype=wptype),
2712
- tr5: wp.array(dtype=wptype),
2713
- ):
2714
- # multiply outputs by 2 so we've got something to backpropagate:
2715
- tr2[0] = wptype(2) * wp.trace(v2[0])
2716
- tr3[0] = wptype(2) * wp.trace(v3[0])
2717
- tr4[0] = wptype(2) * wp.trace(v4[0])
2718
- tr5[0] = wptype(2) * wp.trace(v5[0])
2719
-
2720
- kernel = getkernel(check_mat_trace, suffix=dtype.__name__)
2721
-
2722
- if register_kernels:
2723
- return
2724
-
2725
- v2 = wp.array(randvals([1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2726
- v3 = wp.array(randvals([1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2727
- v4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2728
- v5 = wp.array(randvals([1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2729
- tr2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2730
- tr3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2731
- tr4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2732
- tr5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2733
-
2734
- tape = wp.Tape()
2735
- with tape:
2736
- wp.launch(
2737
- kernel,
2738
- dim=1,
2739
- inputs=[
2740
- v2,
2741
- v3,
2742
- v4,
2743
- v5,
2744
- ],
2745
- outputs=[
2746
- tr2,
2747
- tr3,
2748
- tr4,
2749
- tr5,
2750
- ],
2751
- device=device,
2752
- )
2753
-
2754
- assert_np_equal(tr2.numpy()[0], 2 * np.trace(v2.numpy()[0]), tol=10 * tol)
2755
- assert_np_equal(tr3.numpy()[0], 2 * np.trace(v3.numpy()[0]), tol=10 * tol)
2756
- assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
2757
- assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
2758
-
2759
- if dtype in np_float_types:
2760
- tape.backward(loss=tr2)
2761
- vgrads = tape.gradients[v2].numpy()[0]
2762
- assert_np_equal(vgrads, 2.0 * np.eye(2), tol=10 * tol)
2763
- tape.zero()
2764
-
2765
- tape.backward(loss=tr3)
2766
- vgrads = tape.gradients[v3].numpy()[0]
2767
- assert_np_equal(vgrads, 2.0 * np.eye(3), tol=10 * tol)
2768
- tape.zero()
2769
-
2770
- tape.backward(loss=tr4)
2771
- vgrads = tape.gradients[v4].numpy()[0]
2772
- assert_np_equal(vgrads, 2.0 * np.eye(4), tol=10 * tol)
2773
- tape.zero()
2774
-
2775
- tape.backward(loss=tr5)
2776
- vgrads = tape.gradients[v5].numpy()[0]
2777
- assert_np_equal(vgrads, 2.0 * np.eye(5), tol=10 * tol)
2778
- tape.zero()
2779
-
2780
-
2781
- def test_diag(test, device, dtype, register_kernels=False):
2782
- np.random.seed(123)
2783
-
2784
- tol = {
2785
- np.float16: 1.0e-3,
2786
- np.float32: 1.0e-6,
2787
- np.float64: 1.0e-8,
2788
- }.get(dtype, 0)
2789
-
2790
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2791
- vec5 = wp.types.vector(length=5, dtype=wptype)
2792
-
2793
- output_select_kernel = get_select_kernel(wptype)
2794
-
2795
- def check_mat_diag(
2796
- s5: wp.array(dtype=vec5),
2797
- outcomponents: wp.array(dtype=wptype),
2798
- ):
2799
- # multiply outputs by 2 so we've got something to backpropagate:
2800
- m55result = wptype(2) * wp.diag(s5[0])
2801
-
2802
- idx = 0
2803
- for i in range(5):
2804
- for j in range(5):
2805
- outcomponents[idx] = m55result[i, j]
2806
- idx = idx + 1
2807
-
2808
- kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
2809
-
2810
- if register_kernels:
2811
- return
2812
-
2813
- s5 = wp.array(randvals([1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
2814
- outcomponents = wp.zeros(5 * 5, dtype=wptype, requires_grad=True, device=device)
2815
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2816
-
2817
- wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
2818
-
2819
- assert_np_equal(outcomponents.numpy(), 2 * np.diag(s5.numpy()[0]), tol=tol)
2820
-
2821
- if dtype in np_float_types:
2822
- idx = 0
2823
- for i in range(5):
2824
- for j in range(5):
2825
- tape = wp.Tape()
2826
- with tape:
2827
- wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
2828
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
2829
- tape.backward(loss=out)
2830
- expectedresult = np.zeros(5, dtype=dtype)
2831
- if i == j:
2832
- expectedresult[i] = 2
2833
- assert_np_equal(tape.gradients[s5].numpy()[0], expectedresult, tol=10 * tol)
2834
- tape.zero()
2835
-
2836
- idx = idx + 1
2837
-
2838
-
2839
- def test_get_diag(test, device, dtype, register_kernels=False):
2840
- np.random.seed(123)
2841
-
2842
- tol = {
2843
- np.float16: 1.0e-3,
2844
- np.float32: 1.0e-6,
2845
- np.float64: 1.0e-8,
2846
- }.get(dtype, 0)
2847
-
2848
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2849
- mat55 = wp.types.vector(shape=(5, 5), dtype=wptype)
2850
-
2851
- output_select_kernel = get_select_kernel(wptype)
2852
-
2853
- def check_mat_diag(
2854
- m55: wp.array(dtype=mat55),
2855
- outcomponents: wp.array(dtype=wptype),
2856
- ):
2857
- # multiply outputs by 2 so we've got something to backpropagate:
2858
- vec5result = wptype(2) * wp.get_diag(m55[0])
2859
-
2860
- idx = 0
2861
- for i in range(5):
2862
- outcomponents[idx] = vec5result[i]
2863
- idx = idx + 1
2864
-
2865
- kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
2866
-
2867
- if register_kernels:
2868
- return
2869
-
2870
- m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
2871
- outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
2872
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2873
-
2874
- wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
742
+ det2,
743
+ det3,
744
+ det4,
745
+ ],
746
+ device=device,
747
+ )
748
+ dminus = det3.numpy()[0]
749
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
2875
750
 
2876
- assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
751
+ for i in range(4):
752
+ for j in range(4):
753
+ v4test = v4.numpy()
754
+ v4test[0, i, j] += dx
755
+ wp.launch(
756
+ kernel,
757
+ dim=1,
758
+ inputs=[
759
+ v2,
760
+ v3,
761
+ wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
762
+ ],
763
+ outputs=[
764
+ det2,
765
+ det3,
766
+ det4,
767
+ ],
768
+ device=device,
769
+ )
770
+ dplus = det4.numpy()[0]
771
+ v4test[0, i, j] -= 2.0 * dx
772
+ wp.launch(
773
+ kernel,
774
+ dim=1,
775
+ inputs=[
776
+ v2,
777
+ v3,
778
+ wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
779
+ ],
780
+ outputs=[
781
+ det2,
782
+ det3,
783
+ det4,
784
+ ],
785
+ device=device,
786
+ )
787
+ dminus = det4.numpy()[0]
788
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
2877
789
 
2878
- if dtype in np_float_types:
2879
- idx = 0
2880
- for i in range(5):
2881
- tape = wp.Tape()
2882
- with tape:
2883
- wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
2884
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
2885
- tape.backward(loss=out)
2886
- expectedresult = np.zeros((5, 5), dtype=dtype)
2887
- expectedresult[i, i] = 2
2888
- assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
2889
- tape.zero()
2890
790
 
2891
- idx = idx + 1
791
+ # Unused. Why?
792
+ # def test_get_diag(test, device, dtype, register_kernels=False):
793
+ # tol = {
794
+ # np.float16: 1.0e-3,
795
+ # np.float32: 1.0e-6,
796
+ # np.float64: 1.0e-8,
797
+ # }.get(dtype, 0)
798
+ #
799
+ # wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
800
+ # mat55 = wp.types.vector(shape=(5, 5), dtype=wptype)
801
+ #
802
+ # output_select_kernel = get_select_kernel(wptype)
803
+ #
804
+ # def check_mat_diag(
805
+ # m55: wp.array(dtype=mat55),
806
+ # outcomponents: wp.array(dtype=wptype),
807
+ # ):
808
+ # # multiply outputs by 2 so we've got something to backpropagate:
809
+ # vec5result = wptype(2) * wp.get_diag(m55[0])
810
+ #
811
+ # idx = 0
812
+ # for i in range(5):
813
+ # outcomponents[idx] = vec5result[i]
814
+ # idx = idx + 1
815
+ #
816
+ # kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
817
+ #
818
+ # if register_kernels:
819
+ # return
820
+ #
821
+ # m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
822
+ # outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
823
+ # out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
824
+ #
825
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
826
+ #
827
+ # assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
828
+ #
829
+ # if dtype in np_float_types:
830
+ # idx = 0
831
+ # for i in range(5):
832
+ # tape = wp.Tape()
833
+ # with tape:
834
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
835
+ # wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
836
+ # tape.backward(loss=out)
837
+ # expectedresult = np.zeros((5, 5), dtype=dtype)
838
+ # expectedresult[i, i] = 2
839
+ # assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
840
+ # tape.zero()
841
+ #
842
+ # idx = idx + 1
2892
843
 
2893
844
 
2894
845
  def test_inverse(test, device, dtype, register_kernels=False):
2895
- np.random.seed(123)
846
+ rng = np.random.default_rng(123)
2896
847
 
2897
848
  tol = {
2898
- np.float16: 2.0e-3,
2899
- np.float32: 1.0e-6,
849
+ np.float16: 5.0e-2,
850
+ np.float32: 1.0e-5,
2900
851
  np.float64: 1.0e-8,
2901
852
  }.get(dtype, 0)
2902
853
 
@@ -2939,9 +890,15 @@ def test_inverse(test, device, dtype, register_kernels=False):
2939
890
  if register_kernels:
2940
891
  return
2941
892
 
2942
- m2 = wp.array(2 * (randvals([1, 2, 2], dtype) + 0.2 * np.eye(2)), dtype=mat22, requires_grad=True, device=device)
2943
- m3 = wp.array(2 * (randvals([1, 3, 3], dtype) + 0.2 * np.eye(3)), dtype=mat33, requires_grad=True, device=device)
2944
- m4 = wp.array(2 * (randvals([1, 4, 4], dtype) + 0.2 * np.eye(4)), dtype=mat44, requires_grad=True, device=device)
893
+ m2 = wp.array(
894
+ 2 * (randvals(rng, [1, 2, 2], dtype) + 0.2 * np.eye(2)), dtype=mat22, requires_grad=True, device=device
895
+ )
896
+ m3 = wp.array(
897
+ 2 * (randvals(rng, [1, 3, 3], dtype) + 0.2 * np.eye(3)), dtype=mat33, requires_grad=True, device=device
898
+ )
899
+ m4 = wp.array(
900
+ 2 * (randvals(rng, [1, 4, 4], dtype) + 0.2 * np.eye(4)), dtype=mat44, requires_grad=True, device=device
901
+ )
2945
902
 
2946
903
  outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4, dtype=wptype, requires_grad=True, device=device)
2947
904
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3056,7 +1013,7 @@ def test_inverse(test, device, dtype, register_kernels=False):
3056
1013
 
3057
1014
 
3058
1015
  def test_svd(test, device, dtype, register_kernels=False):
3059
- np.random.seed(123)
1016
+ rng = np.random.default_rng(123)
3060
1017
 
3061
1018
  tol = {
3062
1019
  np.float16: 1.0e-3,
@@ -3108,7 +1065,7 @@ def test_svd(test, device, dtype, register_kernels=False):
3108
1065
  if register_kernels:
3109
1066
  return
3110
1067
 
3111
- m3 = wp.array(randvals([1, 3, 3], dtype) + np.eye(3), dtype=mat33, requires_grad=True, device=device)
1068
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype) + np.eye(3), dtype=mat33, requires_grad=True, device=device)
3112
1069
 
3113
1070
  outcomponents = wp.zeros(2 * 3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
3114
1071
  Uout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
@@ -3175,7 +1132,7 @@ def test_svd(test, device, dtype, register_kernels=False):
3175
1132
 
3176
1133
 
3177
1134
  def test_qr(test, device, dtype, register_kernels=False):
3178
- np.random.seed(123)
1135
+ rng = np.random.default_rng(123)
3179
1136
 
3180
1137
  tol = {
3181
1138
  np.float16: 2.0e-3,
@@ -3218,7 +1175,7 @@ def test_qr(test, device, dtype, register_kernels=False):
3218
1175
  if register_kernels:
3219
1176
  return
3220
1177
 
3221
- m3 = wp.array(0.5 * (randvals([1, 3, 3], dtype) + np.eye(3)), dtype=mat33, requires_grad=True, device=device)
1178
+ m3 = wp.array(0.5 * (randvals(rng, [1, 3, 3], dtype) + np.eye(3)), dtype=mat33, requires_grad=True, device=device)
3222
1179
 
3223
1180
  outcomponents = wp.zeros(2 * 3 * 3, dtype=wptype, requires_grad=True, device=device)
3224
1181
  Qout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
@@ -3287,7 +1244,7 @@ def test_qr(test, device, dtype, register_kernels=False):
3287
1244
 
3288
1245
 
3289
1246
  def test_eig(test, device, dtype, register_kernels=False):
3290
- np.random.seed(123)
1247
+ rng = np.random.default_rng(123)
3291
1248
 
3292
1249
  tol = {
3293
1250
  np.float16: 4.0e-2,
@@ -3330,7 +1287,7 @@ def test_eig(test, device, dtype, register_kernels=False):
3330
1287
  if register_kernels:
3331
1288
  return
3332
1289
 
3333
- m3_np = randvals([1, 3, 3], dtype) + np.eye(3, dtype=dtype)
1290
+ m3_np = randvals(rng, [1, 3, 3], dtype) + np.eye(3, dtype=dtype)
3334
1291
  m3 = wp.array(m3_np, dtype=mat33, requires_grad=True, device=device)
3335
1292
 
3336
1293
  outcomponents = wp.zeros(3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
@@ -3399,7 +1356,7 @@ def test_eig(test, device, dtype, register_kernels=False):
3399
1356
 
3400
1357
 
3401
1358
  def test_skew(test, device, dtype, register_kernels=False):
3402
- np.random.seed(123)
1359
+ rng = np.random.default_rng(123)
3403
1360
 
3404
1361
  tol = {
3405
1362
  np.float16: 1.0e-3,
@@ -3430,7 +1387,7 @@ def test_skew(test, device, dtype, register_kernels=False):
3430
1387
  if register_kernels:
3431
1388
  return
3432
1389
 
3433
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1390
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
3434
1391
 
3435
1392
  outcomponents = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
3436
1393
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3501,7 +1458,7 @@ def test_skew(test, device, dtype, register_kernels=False):
3501
1458
 
3502
1459
 
3503
1460
  def test_transform_point(test, device, dtype, register_kernels=False):
3504
- np.random.seed(123)
1461
+ rng = np.random.default_rng(123)
3505
1462
 
3506
1463
  tol = {
3507
1464
  np.float16: 5.0e-3,
@@ -3532,8 +1489,8 @@ def test_transform_point(test, device, dtype, register_kernels=False):
3532
1489
  if register_kernels:
3533
1490
  return
3534
1491
 
3535
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
3536
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1492
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1493
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
3537
1494
 
3538
1495
  outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
3539
1496
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3562,7 +1519,7 @@ def test_transform_point(test, device, dtype, register_kernels=False):
3562
1519
 
3563
1520
 
3564
1521
  def test_transform_vector(test, device, dtype, register_kernels=False):
3565
- np.random.seed(123)
1522
+ rng = np.random.default_rng(123)
3566
1523
 
3567
1524
  tol = {
3568
1525
  np.float16: 5.0e-3,
@@ -3593,8 +1550,8 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
3593
1550
  if register_kernels:
3594
1551
  return
3595
1552
 
3596
- v3 = wp.array(randvals([1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
3597
- m4 = wp.array(randvals([1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1553
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1554
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
3598
1555
 
3599
1556
  outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
3600
1557
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
@@ -3621,338 +1578,6 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
3621
1578
  tape.zero()
3622
1579
 
3623
1580
 
3624
- def test_anon_type_instance(test, device, dtype, register_kernels=False):
3625
- np.random.seed(123)
3626
-
3627
- tol = {
3628
- np.float16: 5.0e-3,
3629
- np.float32: 1.0e-6,
3630
- np.float64: 1.0e-8,
3631
- }.get(dtype, 0)
3632
-
3633
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3634
-
3635
- def check_scalar_init(
3636
- input: wp.array(dtype=wptype),
3637
- output: wp.array(dtype=wptype),
3638
- ):
3639
- m2result = wp.matrix(input[0], shape=(2, 2))
3640
- m3result = wp.matrix(input[1], shape=(3, 3))
3641
- m4result = wp.matrix(input[2], shape=(4, 4))
3642
- m5result = wp.matrix(input[3], shape=(5, 5))
3643
- m32result = wp.matrix(input[4], shape=(3, 2))
3644
-
3645
- idx = 0
3646
- for i in range(2):
3647
- for j in range(2):
3648
- output[idx] = wptype(2) * m2result[i, j]
3649
- idx = idx + 1
3650
- for i in range(3):
3651
- for j in range(3):
3652
- output[idx] = wptype(2) * m3result[i, j]
3653
- idx = idx + 1
3654
- for i in range(4):
3655
- for j in range(4):
3656
- output[idx] = wptype(2) * m4result[i, j]
3657
- idx = idx + 1
3658
- for i in range(5):
3659
- for j in range(5):
3660
- output[idx] = wptype(2) * m5result[i, j]
3661
- idx = idx + 1
3662
- for i in range(3):
3663
- for j in range(2):
3664
- output[idx] = wptype(2) * m32result[i, j]
3665
- idx = idx + 1
3666
-
3667
- def check_component_init(
3668
- input: wp.array(dtype=wptype),
3669
- output: wp.array(dtype=wptype),
3670
- ):
3671
- m2result = wp.matrix(input[0], input[1], input[2], input[3], shape=(2, 2))
3672
- m3result = wp.matrix(
3673
- input[4], input[5], input[6], input[7], input[8], input[9], input[10], input[11], input[12], shape=(3, 3)
3674
- )
3675
- m4result = wp.matrix(
3676
- input[13],
3677
- input[14],
3678
- input[15],
3679
- input[16],
3680
- input[17],
3681
- input[18],
3682
- input[19],
3683
- input[20],
3684
- input[21],
3685
- input[22],
3686
- input[23],
3687
- input[24],
3688
- input[25],
3689
- input[26],
3690
- input[27],
3691
- input[28],
3692
- shape=(4, 4),
3693
- )
3694
- m5result = wp.matrix(
3695
- input[29],
3696
- input[30],
3697
- input[31],
3698
- input[32],
3699
- input[33],
3700
- input[34],
3701
- input[35],
3702
- input[36],
3703
- input[37],
3704
- input[38],
3705
- input[39],
3706
- input[40],
3707
- input[41],
3708
- input[42],
3709
- input[43],
3710
- input[44],
3711
- input[45],
3712
- input[46],
3713
- input[47],
3714
- input[48],
3715
- input[49],
3716
- input[50],
3717
- input[51],
3718
- input[52],
3719
- input[53],
3720
- shape=(5, 5),
3721
- )
3722
- m32result = wp.matrix(input[54], input[55], input[56], input[57], input[58], input[59], shape=(3, 2))
3723
-
3724
- idx = 0
3725
- for i in range(2):
3726
- for j in range(2):
3727
- output[idx] = wptype(2) * m2result[i, j]
3728
- idx = idx + 1
3729
- for i in range(3):
3730
- for j in range(3):
3731
- output[idx] = wptype(2) * m3result[i, j]
3732
- idx = idx + 1
3733
- for i in range(4):
3734
- for j in range(4):
3735
- output[idx] = wptype(2) * m4result[i, j]
3736
- idx = idx + 1
3737
- for i in range(5):
3738
- for j in range(5):
3739
- output[idx] = wptype(2) * m5result[i, j]
3740
- idx = idx + 1
3741
- for i in range(3):
3742
- for j in range(2):
3743
- output[idx] = wptype(2) * m32result[i, j]
3744
- idx = idx + 1
3745
-
3746
- scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
3747
- component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
3748
- output_select_kernel = get_select_kernel(wptype)
3749
-
3750
- if register_kernels:
3751
- return
3752
-
3753
- input = wp.array(randvals([5], dtype), requires_grad=True, device=device)
3754
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2, dtype=wptype, requires_grad=True, device=device)
3755
-
3756
- wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3757
-
3758
- assert_np_equal(output.numpy()[:4], 2 * np.array([input.numpy()[0]] * 2 * 2), tol=1.0e-6)
3759
- assert_np_equal(output.numpy()[4:13], 2 * np.array([input.numpy()[1]] * 3 * 3), tol=1.0e-6)
3760
- assert_np_equal(output.numpy()[13:29], 2 * np.array([input.numpy()[2]] * 4 * 4), tol=1.0e-6)
3761
- assert_np_equal(output.numpy()[29:54], 2 * np.array([input.numpy()[3]] * 5 * 5), tol=1.0e-6)
3762
- assert_np_equal(output.numpy()[54:], 2 * np.array([input.numpy()[4]] * 3 * 2), tol=1.0e-6)
3763
-
3764
- if dtype in np_float_types:
3765
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3766
- for i in range(len(output)):
3767
- tape = wp.Tape()
3768
- with tape:
3769
- wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3770
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
3771
-
3772
- tape.backward(loss=out)
3773
- expected = np.zeros_like(input.numpy())
3774
- if i < 4:
3775
- expected[0] = 2
3776
- elif i < 13:
3777
- expected[1] = 2
3778
- elif i < 29:
3779
- expected[2] = 2
3780
- elif i < 54:
3781
- expected[3] = 2
3782
- else:
3783
- expected[4] = 2
3784
-
3785
- assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
3786
-
3787
- tape.reset()
3788
- tape.zero()
3789
-
3790
- input = wp.array(randvals([2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2], dtype), requires_grad=True, device=device)
3791
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2, dtype=wptype, requires_grad=True, device=device)
3792
-
3793
- wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3794
-
3795
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=1.0e-6)
3796
-
3797
- if dtype in np_float_types:
3798
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3799
- for i in range(len(output)):
3800
- tape = wp.Tape()
3801
- with tape:
3802
- wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3803
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
3804
-
3805
- tape.backward(loss=out)
3806
- expected = np.zeros_like(input.numpy())
3807
- expected[i] = 2
3808
-
3809
- assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
3810
-
3811
- tape.reset()
3812
- tape.zero()
3813
-
3814
-
3815
- def test_identity(test, device, dtype, register_kernels=False):
3816
- np.random.seed(123)
3817
-
3818
- tol = {
3819
- np.float16: 5.0e-3,
3820
- np.float32: 1.0e-6,
3821
- np.float64: 1.0e-8,
3822
- }.get(dtype, 0)
3823
-
3824
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3825
-
3826
- def check_identity_mat(
3827
- output: wp.array(dtype=wptype),
3828
- ):
3829
- m2result = wp.identity(dtype=wptype, n=2)
3830
- m3result = wp.identity(dtype=wptype, n=3)
3831
- m4result = wp.identity(dtype=wptype, n=4)
3832
- m5result = wp.identity(dtype=wptype, n=5)
3833
-
3834
- idx = 0
3835
- for i in range(2):
3836
- for j in range(2):
3837
- output[idx] = wptype(2) * m2result[i, j]
3838
- idx = idx + 1
3839
- for i in range(3):
3840
- for j in range(3):
3841
- output[idx] = wptype(2) * m3result[i, j]
3842
- idx = idx + 1
3843
- for i in range(4):
3844
- for j in range(4):
3845
- output[idx] = wptype(2) * m4result[i, j]
3846
- idx = idx + 1
3847
- for i in range(5):
3848
- for j in range(5):
3849
- output[idx] = wptype(2) * m5result[i, j]
3850
- idx = idx + 1
3851
-
3852
- id_kernel = getkernel(check_identity_mat, suffix=dtype.__name__)
3853
-
3854
- if register_kernels:
3855
- return
3856
-
3857
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
3858
- wp.launch(id_kernel, dim=1, inputs=[], outputs=[output], device=device)
3859
- assert_np_equal(output.numpy()[:4], 2 * np.eye(2), tol=1.0e-6)
3860
- assert_np_equal(output.numpy()[4:13], 2 * np.eye(3), tol=1.0e-6)
3861
- assert_np_equal(output.numpy()[13:29], 2 * np.eye(4), tol=1.0e-6)
3862
- assert_np_equal(output.numpy()[29:], 2 * np.eye(5), tol=1.0e-6)
3863
-
3864
-
3865
- def test_equivalent_types(test, device, dtype, register_kernels=False):
3866
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3867
-
3868
- # matrix types
3869
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
3870
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
3871
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
3872
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
3873
-
3874
- # matrix types equivalent to the above
3875
- mat22_equiv = wp.types.matrix(shape=(2, 2), dtype=wptype)
3876
- mat33_equiv = wp.types.matrix(shape=(3, 3), dtype=wptype)
3877
- mat44_equiv = wp.types.matrix(shape=(4, 4), dtype=wptype)
3878
- mat55_equiv = wp.types.matrix(shape=(5, 5), dtype=wptype)
3879
-
3880
- # declare kernel with original types
3881
- def check_equivalence(
3882
- m2: mat22,
3883
- m3: mat33,
3884
- m4: mat44,
3885
- m5: mat55,
3886
- ):
3887
- wp.expect_eq(m2, mat22(wptype(42)))
3888
- wp.expect_eq(m3, mat33(wptype(43)))
3889
- wp.expect_eq(m4, mat44(wptype(44)))
3890
- wp.expect_eq(m5, mat55(wptype(45)))
3891
-
3892
- wp.expect_eq(m2, mat22_equiv(wptype(42)))
3893
- wp.expect_eq(m3, mat33_equiv(wptype(43)))
3894
- wp.expect_eq(m4, mat44_equiv(wptype(44)))
3895
- wp.expect_eq(m5, mat55_equiv(wptype(45)))
3896
-
3897
- kernel = getkernel(check_equivalence, suffix=dtype.__name__)
3898
-
3899
- if register_kernels:
3900
- return
3901
-
3902
- # call kernel with equivalent types
3903
- m2 = mat22_equiv(42)
3904
- m3 = mat33_equiv(43)
3905
- m4 = mat44_equiv(44)
3906
- m5 = mat55_equiv(45)
3907
-
3908
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], device=device)
3909
-
3910
-
3911
- def test_conversions(test, device, dtype, register_kernels=False):
3912
- def check_matrices_equal(
3913
- m0: wp.mat22,
3914
- m1: wp.mat22,
3915
- m2: wp.mat22,
3916
- m3: wp.mat22,
3917
- m4: wp.mat22,
3918
- m5: wp.mat22,
3919
- m6: wp.mat22,
3920
- ):
3921
- wp.expect_eq(m1, m0)
3922
- wp.expect_eq(m2, m0)
3923
- wp.expect_eq(m3, m0)
3924
- wp.expect_eq(m4, m0)
3925
- wp.expect_eq(m5, m0)
3926
- wp.expect_eq(m6, m0)
3927
-
3928
- kernel = getkernel(check_matrices_equal, suffix=dtype.__name__)
3929
-
3930
- if register_kernels:
3931
- return
3932
-
3933
- m0 = wp.mat22(1, 2, 3, 4)
3934
-
3935
- # test explicit conversions - constructing matrices from different containers
3936
- m1 = wp.mat22(((1, 2), (3, 4))) # nested tuples
3937
- m2 = wp.mat22([[1, 2], [3, 4]]) # nested lists
3938
- m3 = wp.mat22(np.array([[1, 2], [3, 4]], dtype=dtype)) # 2d array
3939
- m4 = wp.mat22((1, 2, 3, 4)) # flat tuple
3940
- m5 = wp.mat22([1, 2, 3, 4]) # flat list
3941
- m6 = wp.mat22(np.array([1, 2, 3, 4], dtype=dtype)) # 1d array
3942
-
3943
- wp.launch(kernel, dim=1, inputs=[m0, m1, m2, m3, m4, m5, m6], device=device)
3944
-
3945
- # test implicit conversions - passing different containers as matrices to wp.launch()
3946
- m1 = ((1, 2), (3, 4)) # nested tuples
3947
- m2 = [[1, 2], [3, 4]] # nested lists
3948
- m3 = np.array([[1, 2], [3, 4]], dtype=dtype) # 2d array
3949
- m4 = (1, 2, 3, 4) # flat tuple
3950
- m5 = [1, 2, 3, 4] # flat list
3951
- m6 = np.array([1, 2, 3, 4], dtype=dtype) # 1d array
3952
-
3953
- wp.launch(kernel, dim=1, inputs=[m0, m1, m2, m3, m4, m5, m6], device=device)
3954
-
3955
-
3956
1581
  # Test matrix constructors using explicit type (float16)
3957
1582
  # note that these tests are specifically not using generics / closure
3958
1583
  # args to create kernels dynamically (like the rest of this file)
@@ -3976,6 +1601,22 @@ def test_constructors_explicit_precision():
3976
1601
  wp.expect_eq(custom[i, j], wp.float16(i) * wp.float16(2.0) + wp.float16(j))
3977
1602
 
3978
1603
 
1604
+ mat32d = wp.mat(shape=(3, 2), dtype=wp.float64)
1605
+
1606
+
1607
+ @wp.kernel
1608
+ def test_matrix_constructor_value_func():
1609
+ a = wp.mat22()
1610
+ b = wp.matrix(a, shape=(2, 2))
1611
+ c = mat32d()
1612
+ d = mat32d(c, shape=(3, 2))
1613
+ e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))
1614
+ f = mat32d(
1615
+ wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1616
+ wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1617
+ )
1618
+
1619
+
3979
1620
  # Same as above but with a default (float/int) type
3980
1621
  # which tests some different code paths that
3981
1622
  # need to ensure types are correctly canonicalized
@@ -4038,171 +1679,149 @@ def test_constructors_constant_shape():
4038
1679
  m[i, j] = float(i * j)
4039
1680
 
4040
1681
 
4041
- def register(parent):
4042
- devices = get_test_devices()
4043
-
4044
- class TestMat(parent):
4045
- pass
4046
-
4047
- add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
4048
- add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
4049
- add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
4050
-
4051
- mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
4052
- add_kernel_test(
4053
- TestMat,
4054
- test_matrix_mutation,
4055
- dim=1,
4056
- inputs=[
4057
- mat103(
4058
- 1.0,
4059
- 2.0,
4060
- 3.0,
4061
- 2.0,
4062
- 4.0,
4063
- 6.0,
4064
- 3.0,
4065
- 6.0,
4066
- 9.0,
4067
- 4.0,
4068
- 8.0,
4069
- 12.0,
4070
- 5.0,
4071
- 10.0,
4072
- 15.0,
4073
- 6.0,
4074
- 12.0,
4075
- 18.0,
4076
- 7.0,
4077
- 14.0,
4078
- 21.0,
4079
- 8.0,
4080
- 16.0,
4081
- 24.0,
4082
- 9.0,
4083
- 18.0,
4084
- 27.0,
4085
- 10.0,
4086
- 20.0,
4087
- 30.0,
4088
- )
4089
- ],
4090
- devices=devices,
4091
- )
4092
-
4093
- for dtype in np_signed_int_types + np_float_types:
4094
- add_function_test_register_kernel(
4095
- TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
4096
- )
4097
- add_function_test_register_kernel(
4098
- TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
4099
- )
4100
-
4101
- for dtype in np_scalar_types:
4102
- add_function_test(TestMat, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
4103
- add_function_test(TestMat, f"test_components_{dtype.__name__}", test_components, devices=None, dtype=dtype)
4104
- add_function_test_register_kernel(
4105
- TestMat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
4106
- )
4107
- add_function_test_register_kernel(
4108
- TestMat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
4109
- )
4110
- add_function_test_register_kernel(
4111
- TestMat, f"test_identity_{dtype.__name__}", test_identity, devices=devices, dtype=dtype
4112
- )
4113
- add_function_test_register_kernel(
4114
- TestMat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
4115
- )
4116
- add_function_test_register_kernel(
4117
- TestMat, f"test_equality_{dtype.__name__}", test_equality, devices=devices, dtype=dtype
4118
- )
4119
- add_function_test_register_kernel(
4120
- TestMat,
4121
- f"test_scalar_multiplication_{dtype.__name__}",
4122
- test_scalar_multiplication,
4123
- devices=devices,
4124
- dtype=dtype,
4125
- )
4126
- add_function_test_register_kernel(
4127
- TestMat,
4128
- f"test_matvec_multiplication_{dtype.__name__}",
4129
- test_matvec_multiplication,
4130
- devices=devices,
4131
- dtype=dtype,
4132
- )
4133
- add_function_test_register_kernel(
4134
- TestMat,
4135
- f"test_matmat_multiplication_{dtype.__name__}",
4136
- test_matmat_multiplication,
4137
- devices=devices,
4138
- dtype=dtype,
4139
- )
4140
- add_function_test_register_kernel(
4141
- TestMat, f"test_cw_multiplication_{dtype.__name__}", test_cw_multiplication, devices=devices, dtype=dtype
4142
- )
4143
- add_function_test_register_kernel(
4144
- TestMat, f"test_cw_division_{dtype.__name__}", test_cw_division, devices=devices, dtype=dtype
4145
- )
4146
- add_function_test_register_kernel(
4147
- TestMat, f"test_outer_product_{dtype.__name__}", test_outer_product, devices=devices, dtype=dtype
4148
- )
4149
- add_function_test_register_kernel(
4150
- TestMat, f"test_transpose_{dtype.__name__}", test_transpose, devices=devices, dtype=dtype
4151
- )
4152
- add_function_test_register_kernel(
4153
- TestMat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
4154
- )
4155
- add_function_test_register_kernel(
4156
- TestMat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
4157
- )
4158
- add_function_test_register_kernel(
4159
- TestMat, f"test_ddot_{dtype.__name__}", test_ddot, devices=devices, dtype=dtype
4160
- )
4161
- add_function_test_register_kernel(
4162
- TestMat, f"test_trace_{dtype.__name__}", test_trace, devices=devices, dtype=dtype
4163
- )
4164
- add_function_test_register_kernel(
4165
- TestMat, f"test_diag_{dtype.__name__}", test_diag, devices=devices, dtype=dtype
4166
- )
4167
- add_function_test_register_kernel(
4168
- TestMat, f"test_get_diag_{dtype.__name__}", test_diag, devices=devices, dtype=dtype
4169
- )
4170
- add_function_test_register_kernel(
4171
- TestMat, f"test_equivalent_types_{dtype.__name__}", test_equivalent_types, devices=devices, dtype=dtype
4172
- )
4173
- add_function_test_register_kernel(
4174
- TestMat, f"test_conversions_{dtype.__name__}", test_conversions, devices=devices, dtype=dtype
4175
- )
4176
- add_function_test_register_kernel(
4177
- TestMat, f"test_constants_{dtype.__name__}", test_constants, devices=devices, dtype=dtype
1682
+ devices = get_test_devices()
1683
+
1684
+
1685
+ class TestMat(unittest.TestCase):
1686
+ pass
1687
+
1688
+
1689
+ add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
1690
+ add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
1691
+ add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
1692
+ add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
1693
+
1694
+ mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
1695
+ add_kernel_test(
1696
+ TestMat,
1697
+ test_matrix_mutation,
1698
+ dim=1,
1699
+ inputs=[
1700
+ mat103(
1701
+ 1.0,
1702
+ 2.0,
1703
+ 3.0,
1704
+ 2.0,
1705
+ 4.0,
1706
+ 6.0,
1707
+ 3.0,
1708
+ 6.0,
1709
+ 9.0,
1710
+ 4.0,
1711
+ 8.0,
1712
+ 12.0,
1713
+ 5.0,
1714
+ 10.0,
1715
+ 15.0,
1716
+ 6.0,
1717
+ 12.0,
1718
+ 18.0,
1719
+ 7.0,
1720
+ 14.0,
1721
+ 21.0,
1722
+ 8.0,
1723
+ 16.0,
1724
+ 24.0,
1725
+ 9.0,
1726
+ 18.0,
1727
+ 27.0,
1728
+ 10.0,
1729
+ 20.0,
1730
+ 30.0,
4178
1731
  )
1732
+ ],
1733
+ devices=devices,
1734
+ )
4179
1735
 
4180
- for dtype in np_float_types:
4181
- add_function_test_register_kernel(
4182
- TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
4183
- )
4184
- add_function_test_register_kernel(
4185
- TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
4186
- )
4187
- add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
4188
- add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
4189
- add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
4190
- add_function_test_register_kernel(
4191
- TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
4192
- )
4193
- add_function_test_register_kernel(
4194
- TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
4195
- )
4196
- add_function_test_register_kernel(
4197
- TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
4198
- )
4199
- add_function_test_register_kernel(
4200
- TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype
4201
- )
1736
+ for dtype in np_signed_int_types + np_float_types:
1737
+ add_function_test_register_kernel(
1738
+ TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
1739
+ )
1740
+ add_function_test_register_kernel(
1741
+ TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
1742
+ )
4202
1743
 
4203
- return TestMat
1744
+ add_function_test(
1745
+ TestMat,
1746
+ "test_anon_constructor_error_shape_keyword_missing",
1747
+ test_anon_constructor_error_shape_keyword_missing,
1748
+ devices=devices,
1749
+ )
1750
+ add_function_test(
1751
+ TestMat,
1752
+ "test_anon_constructor_error_dtype_keyword_missing",
1753
+ test_anon_constructor_error_dtype_keyword_missing,
1754
+ devices=devices,
1755
+ )
1756
+ add_function_test(
1757
+ TestMat,
1758
+ "test_anon_constructor_error_shape_mismatch",
1759
+ test_anon_constructor_error_shape_mismatch,
1760
+ devices=devices,
1761
+ )
1762
+ add_function_test(
1763
+ TestMat,
1764
+ "test_anon_constructor_error_invalid_arg_count",
1765
+ test_anon_constructor_error_invalid_arg_count,
1766
+ devices=devices,
1767
+ )
1768
+ add_function_test(
1769
+ TestMat,
1770
+ "test_tpl_constructor_error_incompatible_sizes",
1771
+ test_tpl_constructor_error_incompatible_sizes,
1772
+ devices=devices,
1773
+ )
1774
+ add_function_test(
1775
+ TestMat,
1776
+ "test_tpl_constructor_error_invalid_scalar_type",
1777
+ test_tpl_constructor_error_invalid_scalar_type,
1778
+ devices=devices,
1779
+ )
1780
+ add_function_test(
1781
+ TestMat,
1782
+ "test_tpl_constructor_error_invalid_vector_count",
1783
+ test_tpl_constructor_error_invalid_vector_count,
1784
+ devices=devices,
1785
+ )
1786
+ add_function_test(
1787
+ TestMat,
1788
+ "test_tpl_constructor_error_invalid_vector_shape",
1789
+ test_tpl_constructor_error_invalid_vector_shape,
1790
+ devices=devices,
1791
+ )
1792
+ add_function_test(
1793
+ TestMat,
1794
+ "test_tpl_constructor_error_invalid_arg_count",
1795
+ test_tpl_constructor_error_invalid_arg_count,
1796
+ devices=devices,
1797
+ )
1798
+ add_function_test(TestMat, "test_tpl_ops_with_anon", test_tpl_ops_with_anon)
1799
+
1800
+ for dtype in np_float_types:
1801
+ add_function_test(
1802
+ TestMat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
1803
+ )
1804
+ add_function_test_register_kernel(
1805
+ TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
1806
+ )
1807
+ add_function_test_register_kernel(
1808
+ TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
1809
+ )
1810
+ add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
1811
+ add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
1812
+ add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
1813
+ add_function_test_register_kernel(
1814
+ TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
1815
+ )
1816
+ add_function_test_register_kernel(
1817
+ TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
1818
+ )
1819
+ add_function_test_register_kernel(
1820
+ TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
1821
+ )
1822
+ add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
4204
1823
 
4205
1824
 
4206
1825
  if __name__ == "__main__":
4207
- c = register(unittest.TestCase)
1826
+ wp.build.clear_kernel_cache()
4208
1827
  unittest.main(verbosity=2, failfast=True)