warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187):
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_mat.py CHANGED
@@ -8,8 +8,9 @@
8
8
  import unittest
9
9
 
10
10
  import numpy as np
11
+
11
12
  import warp as wp
12
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
13
14
 
14
15
  wp.init()
15
16
 
@@ -21,20 +22,8 @@ np_signed_int_types = [
21
22
  np.byte,
22
23
  ]
23
24
 
24
- np_unsigned_int_types = [
25
- np.uint8,
26
- np.uint16,
27
- np.uint32,
28
- np.uint64,
29
- np.ubyte,
30
- ]
31
-
32
- np_int_types = np_signed_int_types + np_unsigned_int_types
33
-
34
25
  np_float_types = [np.float16, np.float32, np.float64]
35
26
 
36
- np_scalar_types = np_int_types + np_float_types
37
-
38
27
 
39
28
  def randvals(rng, shape, dtype):
40
29
  if dtype in np_float_types:
@@ -64,374 +53,9 @@ def get_select_kernel(dtype):
64
53
 
65
54
  return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
66
55
 
67
-
68
- def test_arrays(test, device, dtype):
69
- rng = np.random.default_rng(123)
70
-
71
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
72
-
73
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
74
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
75
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
76
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
77
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
78
-
79
- v2_np = randvals(rng, [10, 2, 2], dtype)
80
- v3_np = randvals(rng, [10, 3, 3], dtype)
81
- v4_np = randvals(rng, [10, 4, 4], dtype)
82
- v5_np = randvals(rng, [10, 5, 5], dtype)
83
- v32_np = randvals(rng, [10, 3, 2], dtype)
84
-
85
- v2 = wp.array(v2_np, dtype=mat22, requires_grad=True, device=device)
86
- v3 = wp.array(v3_np, dtype=mat33, requires_grad=True, device=device)
87
- v4 = wp.array(v4_np, dtype=mat44, requires_grad=True, device=device)
88
- v5 = wp.array(v5_np, dtype=mat55, requires_grad=True, device=device)
89
- v32 = wp.array(v32_np, dtype=mat32, requires_grad=True, device=device)
90
-
91
- assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
92
- assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
93
- assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
94
- assert_np_equal(v5.numpy(), v5_np, tol=1.0e-6)
95
- assert_np_equal(v32.numpy(), v32_np, tol=1.0e-6)
96
-
97
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
98
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
99
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
100
-
101
- v2 = wp.array(v2_np, dtype=mat22, requires_grad=True, device=device)
102
- v3 = wp.array(v3_np, dtype=mat33, requires_grad=True, device=device)
103
- v4 = wp.array(v4_np, dtype=mat44, requires_grad=True, device=device)
104
-
105
- assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
106
- assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
107
- assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
108
-
109
-
110
- def test_components(test, device, dtype):
111
- # test accessing matrix components from Python - this is especially important
112
- # for float16, which requires special handling internally
113
-
114
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
115
- mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)
116
-
117
- m = mat23(1, 2, 3, 4, 5, 6)
118
-
119
- # test __getitem__ for row vectors
120
- r0 = m[0]
121
- r1 = m[1]
122
- test.assertEqual(r0[0], 1)
123
- test.assertEqual(r0[1], 2)
124
- test.assertEqual(r0[2], 3)
125
- test.assertEqual(r1[0], 4)
126
- test.assertEqual(r1[1], 5)
127
- test.assertEqual(r1[2], 6)
128
-
129
- # test __getitem__ for individual components
130
- test.assertEqual(m[0, 0], 1)
131
- test.assertEqual(m[0, 1], 2)
132
- test.assertEqual(m[0, 2], 3)
133
- test.assertEqual(m[1, 0], 4)
134
- test.assertEqual(m[1, 1], 5)
135
- test.assertEqual(m[1, 2], 6)
136
-
137
- # test __setitem__ for row vectors
138
- m[0] = [7, 8, 9]
139
- m[1] = [10, 11, 12]
140
- test.assertEqual(m[0, 0], 7)
141
- test.assertEqual(m[0, 1], 8)
142
- test.assertEqual(m[0, 2], 9)
143
- test.assertEqual(m[1, 0], 10)
144
- test.assertEqual(m[1, 1], 11)
145
- test.assertEqual(m[1, 2], 12)
146
-
147
- # test __setitem__ for individual components
148
- m[0, 0] = 13
149
- m[0, 1] = 14
150
- m[0, 2] = 15
151
- m[1, 0] = 16
152
- m[1, 1] = 17
153
- m[1, 2] = 18
154
- test.assertEqual(m[0, 0], 13)
155
- test.assertEqual(m[0, 1], 14)
156
- test.assertEqual(m[0, 2], 15)
157
- test.assertEqual(m[1, 0], 16)
158
- test.assertEqual(m[1, 1], 17)
159
- test.assertEqual(m[1, 2], 18)
160
-
161
-
162
- def test_constants(test, device, dtype, register_kernels=False):
163
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
164
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
165
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
166
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
167
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
168
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
169
-
170
- cm22 = wp.constant(mat22(22))
171
- cm33 = wp.constant(mat33(33))
172
- cm44 = wp.constant(mat44(44))
173
- cm55 = wp.constant(mat55(55))
174
- cm32 = wp.constant(mat32(32))
175
-
176
- def check_matrix_constants():
177
- wp.expect_eq(cm22, mat22(wptype(22)))
178
- wp.expect_eq(cm33, mat33(wptype(33)))
179
- wp.expect_eq(cm44, mat44(wptype(44)))
180
- wp.expect_eq(cm55, mat55(wptype(55)))
181
- wp.expect_eq(cm32, mat32(wptype(32)))
182
-
183
- kernel = getkernel(check_matrix_constants, suffix=dtype.__name__)
184
-
185
- if register_kernels:
186
- return
187
-
188
56
  wp.launch(kernel, dim=1, inputs=[])
189
57
 
190
58
 
191
- def test_constructors(test, device, dtype, register_kernels=False):
192
- rng = np.random.default_rng(123)
193
-
194
- tol = {
195
- np.float16: 1.0e-3,
196
- np.float32: 1.0e-6,
197
- np.float64: 1.0e-8,
198
- }.get(dtype, 0)
199
-
200
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
201
- vec2 = wp.types.vector(length=2, dtype=wptype)
202
- vec3 = wp.types.vector(length=3, dtype=wptype)
203
- vec4 = wp.types.vector(length=4, dtype=wptype)
204
- vec5 = wp.types.vector(length=5, dtype=wptype)
205
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
206
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
207
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
208
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
209
-
210
- output_select_kernel = get_select_kernel(wptype)
211
-
212
- def check_scalar_mat_constructor(
213
- input: wp.array(dtype=wptype),
214
- outcomponents: wp.array(dtype=wptype),
215
- ):
216
- # multiply outputs by 2 so we've got something to backpropagate:
217
- m2result = wptype(2) * mat22(input[0])
218
- m3result = wptype(2) * mat33(input[0])
219
- m4result = wptype(2) * mat44(input[0])
220
- m5result = wptype(2) * mat55(input[0])
221
-
222
- idx = 0
223
- for i in range(2):
224
- for j in range(2):
225
- outcomponents[idx] = m2result[i, j]
226
- idx = idx + 1
227
-
228
- for i in range(3):
229
- for j in range(3):
230
- outcomponents[idx] = m3result[i, j]
231
- idx = idx + 1
232
-
233
- for i in range(4):
234
- for j in range(4):
235
- outcomponents[idx] = m4result[i, j]
236
- idx = idx + 1
237
-
238
- for i in range(5):
239
- for j in range(5):
240
- outcomponents[idx] = m5result[i, j]
241
- idx = idx + 1
242
-
243
- def check_component_mat_constructor(
244
- input: wp.array(dtype=wptype),
245
- outcomponents: wp.array(dtype=wptype),
246
- ):
247
- # multiply outputs by 2 so we've got something to backpropagate:
248
- m2result = wptype(2) * mat22(input[0], input[1], input[2], input[3])
249
- m3result = wptype(2) * mat33(
250
- input[4],
251
- input[5],
252
- input[6],
253
- input[7],
254
- input[8],
255
- input[9],
256
- input[10],
257
- input[11],
258
- input[12],
259
- )
260
- m4result = wptype(2) * mat44(
261
- input[13],
262
- input[14],
263
- input[15],
264
- input[16],
265
- input[17],
266
- input[18],
267
- input[19],
268
- input[20],
269
- input[21],
270
- input[22],
271
- input[23],
272
- input[24],
273
- input[25],
274
- input[26],
275
- input[27],
276
- input[28],
277
- )
278
- m5result = wptype(2) * mat55(
279
- input[29],
280
- input[30],
281
- input[31],
282
- input[32],
283
- input[33],
284
- input[34],
285
- input[35],
286
- input[36],
287
- input[37],
288
- input[38],
289
- input[39],
290
- input[40],
291
- input[41],
292
- input[42],
293
- input[43],
294
- input[44],
295
- input[45],
296
- input[46],
297
- input[47],
298
- input[48],
299
- input[49],
300
- input[50],
301
- input[51],
302
- input[52],
303
- input[53],
304
- )
305
-
306
- idx = 0
307
- for i in range(2):
308
- for j in range(2):
309
- outcomponents[idx] = m2result[i, j]
310
- idx = idx + 1
311
-
312
- for i in range(3):
313
- for j in range(3):
314
- outcomponents[idx] = m3result[i, j]
315
- idx = idx + 1
316
-
317
- for i in range(4):
318
- for j in range(4):
319
- outcomponents[idx] = m4result[i, j]
320
- idx = idx + 1
321
-
322
- for i in range(5):
323
- for j in range(5):
324
- outcomponents[idx] = m5result[i, j]
325
- idx = idx + 1
326
-
327
- def check_vector_mat_constructor(
328
- input: wp.array(dtype=wptype),
329
- outcomponents: wp.array(dtype=wptype),
330
- ):
331
- # multiply outputs by 2 so we've got something to backpropagate:
332
- m2result = wptype(2) * mat22(vec2(input[0], input[2]), vec2(input[1], input[3]))
333
- m3result = wptype(2) * mat33(
334
- vec3(input[4], input[7], input[10]),
335
- vec3(input[5], input[8], input[11]),
336
- vec3(input[6], input[9], input[12]),
337
- )
338
- m4result = wptype(2) * mat44(
339
- vec4(input[13], input[17], input[21], input[25]),
340
- vec4(input[14], input[18], input[22], input[26]),
341
- vec4(input[15], input[19], input[23], input[27]),
342
- vec4(input[16], input[20], input[24], input[28]),
343
- )
344
- m5result = wptype(2) * mat55(
345
- vec5(input[29], input[34], input[39], input[44], input[49]),
346
- vec5(input[30], input[35], input[40], input[45], input[50]),
347
- vec5(input[31], input[36], input[41], input[46], input[51]),
348
- vec5(input[32], input[37], input[42], input[47], input[52]),
349
- vec5(input[33], input[38], input[43], input[48], input[53]),
350
- )
351
-
352
- idx = 0
353
- for i in range(2):
354
- for j in range(2):
355
- outcomponents[idx] = m2result[i, j]
356
- idx = idx + 1
357
-
358
- for i in range(3):
359
- for j in range(3):
360
- outcomponents[idx] = m3result[i, j]
361
- idx = idx + 1
362
-
363
- for i in range(4):
364
- for j in range(4):
365
- outcomponents[idx] = m4result[i, j]
366
- idx = idx + 1
367
-
368
- for i in range(5):
369
- for j in range(5):
370
- outcomponents[idx] = m5result[i, j]
371
- idx = idx + 1
372
-
373
- kernel = getkernel(check_scalar_mat_constructor, suffix=dtype.__name__)
374
- compkernel = getkernel(check_component_mat_constructor, suffix=dtype.__name__)
375
- veckernel = getkernel(check_vector_mat_constructor, suffix=dtype.__name__)
376
-
377
- if register_kernels:
378
- return
379
-
380
- input = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
381
- val = input.numpy()[0]
382
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
383
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
384
-
385
- wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
386
-
387
- assert_np_equal(outcomponents.numpy()[:4], 2 * val * np.ones(2 * 2), tol=tol)
388
- assert_np_equal(outcomponents.numpy()[4:13], 2 * val * np.ones(3 * 3), tol=tol)
389
- assert_np_equal(outcomponents.numpy()[13:29], 2 * val * np.ones(4 * 4), tol=tol)
390
- assert_np_equal(outcomponents.numpy()[29:54], 2 * val * np.ones(5 * 5), tol=tol)
391
-
392
- if dtype in np_float_types:
393
- for idx in range(len(outcomponents)):
394
- tape = wp.Tape()
395
- with tape:
396
- wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
397
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
398
- tape.backward(loss=out)
399
- test.assertEqual(tape.gradients[input].numpy()[0], 2)
400
- tape.zero()
401
-
402
- input = wp.array(randvals(rng, [2 * 2 + 3 * 3 + 4 * 4 + 5 * 5], dtype), requires_grad=True, device=device)
403
-
404
- wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
405
- assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
406
-
407
- if dtype in np_float_types:
408
- for idx in range(len(outcomponents)):
409
- tape = wp.Tape()
410
- with tape:
411
- wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
412
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
413
- tape.backward(loss=out)
414
- expectedgrads = np.zeros(len(input))
415
- expectedgrads[idx] = 2
416
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
417
- tape.zero()
418
-
419
- wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
420
- assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
421
-
422
- if dtype in np_float_types:
423
- for idx in range(len(outcomponents)):
424
- tape = wp.Tape()
425
- with tape:
426
- wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
427
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
428
- tape.backward(loss=out)
429
- expectedgrads = np.zeros(len(input))
430
- expectedgrads[idx] = 2
431
- assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
432
- tape.zero()
433
-
434
-
435
59
  def test_anon_constructor_error_shape_keyword_missing(test, device):
436
60
  @wp.kernel
437
61
  def kernel():
@@ -604,6 +228,47 @@ def test_tpl_ops_with_anon(test, device):
604
228
  test.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
605
229
 
606
230
 
231
+ def test_py_arithmetic_ops(test, device, dtype):
232
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
233
+
234
+ def make_mat(*args):
235
+ if wptype in wp.types.int_types:
236
+ # Cast to the correct integer type to simulate wrapping.
237
+ return tuple(tuple(wptype._type_(x).value for x in row) for row in args)
238
+
239
+ return args
240
+
241
+ def make_vec(*args):
242
+ if wptype in wp.types.int_types:
243
+ # Cast to the correct integer type to simulate wrapping.
244
+ return tuple(wptype._type_(x).value for x in args)
245
+
246
+ return args
247
+
248
+ mat_cls = wp.mat((3, 3), wptype)
249
+ vec_cls = wp.vec(3, wptype)
250
+
251
+ m = mat_cls(((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
252
+ test.assertSequenceEqual(+m, make_mat((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
253
+ test.assertSequenceEqual(-m, make_mat((1, -2, -3), (-4, 5, -6), (-7, -8, 9)))
254
+ test.assertSequenceEqual(m + mat_cls((5, 5, 5) * 3), make_mat((4, 7, 8), (9, 0, 11), (12, 13, -4)))
255
+ test.assertSequenceEqual(m - mat_cls((5, 5, 5) * 3), make_mat((-6, -3, -2), (-1, -10, 1), (2, 3, -14)))
256
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(20, 25, 30))
257
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(20, 25, 30))
258
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(50, 25, 0))
259
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(50, 25, 0))
260
+
261
+ m = mat_cls(((2, 4, 6), (8, 10, 12), (14, 16, 18)))
262
+ test.assertSequenceEqual(m * wptype(2), make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
263
+ test.assertSequenceEqual(wptype(2) * m, make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
264
+ test.assertSequenceEqual(m / wptype(2), make_mat((1, 2, 3), (4, 5, 6), (7, 8, 9)))
265
+ test.assertSequenceEqual(wptype(5040) / m, make_mat((2520, 1260, 840), (630, 504, 420), (360, 315, 280)))
266
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(60, 150, 240))
267
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(60, 150, 240))
268
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(120, 150, 180))
269
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(120, 150, 180))
270
+
271
+
607
272
  def test_quat_constructor(test, device, dtype, register_kernels=False):
608
273
  rng = np.random.default_rng(123)
609
274
 
@@ -703,11 +368,11 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
703
368
  idx = idx + 1
704
369
 
705
370
 
706
- def test_indexing(test, device, dtype, register_kernels=False):
371
+ def test_negation(test, device, dtype, register_kernels=False):
707
372
  rng = np.random.default_rng(123)
708
373
 
709
374
  tol = {
710
- np.float16: 1.0e-3,
375
+ np.float16: 1.0e-2,
711
376
  np.float32: 1.0e-6,
712
377
  np.float64: 1.0e-8,
713
378
  }.get(dtype, 0)
@@ -720,36 +385,41 @@ def test_indexing(test, device, dtype, register_kernels=False):
720
385
 
721
386
  output_select_kernel = get_select_kernel(wptype)
722
387
 
723
- def check_mat_indexing(
388
+ def check_mat_negation(
724
389
  m2: wp.array(dtype=mat22),
725
390
  m3: wp.array(dtype=mat33),
726
391
  m4: wp.array(dtype=mat44),
727
392
  m5: wp.array(dtype=mat55),
728
393
  outcomponents: wp.array(dtype=wptype),
729
394
  ):
395
+ mat2 = -m2[0]
396
+ mat3 = -m3[0]
397
+ mat4 = -m4[0]
398
+ mat5 = -m5[0]
399
+
730
400
  # multiply outputs by 2 so we've got something to backpropagate:
731
401
  idx = 0
732
402
  for i in range(2):
733
403
  for j in range(2):
734
- outcomponents[idx] = wptype(2) * m2[0][i, j]
404
+ outcomponents[idx] = wptype(2) * mat2[i, j]
735
405
  idx = idx + 1
736
406
 
737
407
  for i in range(3):
738
408
  for j in range(3):
739
- outcomponents[idx] = wptype(2) * m3[0][i, j]
409
+ outcomponents[idx] = wptype(2) * mat3[i, j]
740
410
  idx = idx + 1
741
411
 
742
412
  for i in range(4):
743
413
  for j in range(4):
744
- outcomponents[idx] = wptype(2) * m4[0][i, j]
414
+ outcomponents[idx] = wptype(2) * mat4[i, j]
745
415
  idx = idx + 1
746
416
 
747
417
  for i in range(5):
748
418
  for j in range(5):
749
- outcomponents[idx] = wptype(2) * m5[0][i, j]
419
+ outcomponents[idx] = wptype(2) * mat5[i, j]
750
420
  idx = idx + 1
751
421
 
752
- kernel = getkernel(check_mat_indexing, suffix=dtype.__name__)
422
+ kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
753
423
 
754
424
  if register_kernels:
755
425
  return
@@ -762,10 +432,10 @@ def test_indexing(test, device, dtype, register_kernels=False):
762
432
 
763
433
  wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
764
434
 
765
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy().reshape(-1), tol=tol)
766
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy().reshape(-1), tol=tol)
767
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy().reshape(-1), tol=tol)
768
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy().reshape(-1), tol=tol)
435
+ assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
436
+ assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
437
+ assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
438
+ assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
769
439
 
770
440
  if dtype in np_float_types:
771
441
  idx = 0
@@ -781,283 +451,17 @@ def test_indexing(test, device, dtype, register_kernels=False):
781
451
  )
782
452
  tape.backward(loss=out)
783
453
  expectedresult = np.zeros((dim, dim), dtype=dtype)
784
- expectedresult[i, j] = 2
454
+ expectedresult[i, j] = -2
785
455
  assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
786
456
  tape.zero()
787
457
  idx = idx + 1
788
458
 
789
459
 
790
- def test_equality(test, device, dtype, register_kernels=False):
791
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
792
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
793
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
794
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
795
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
796
-
797
- def check_mat_equality():
798
- wp.expect_eq(
799
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
800
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
801
- )
802
- wp.expect_neq(
803
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), -wptype(4.0)),
804
- mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0)),
805
- )
806
-
807
- wp.expect_eq(
808
- mat33(
809
- wptype(1.0),
810
- wptype(2.0),
811
- wptype(3.0),
812
- wptype(4.0),
813
- wptype(5.0),
814
- wptype(6.0),
815
- wptype(7.0),
816
- wptype(8.0),
817
- wptype(9.0),
818
- ),
819
- mat33(
820
- wptype(1.0),
821
- wptype(2.0),
822
- wptype(3.0),
823
- wptype(4.0),
824
- wptype(5.0),
825
- wptype(6.0),
826
- wptype(7.0),
827
- wptype(8.0),
828
- wptype(9.0),
829
- ),
830
- )
831
- wp.expect_neq(
832
- mat33(
833
- wptype(1.0),
834
- wptype(2.0),
835
- wptype(3.0),
836
- wptype(4.0),
837
- wptype(5.0),
838
- wptype(6.0),
839
- wptype(7.0),
840
- wptype(8.0),
841
- wptype(9.0),
842
- ),
843
- mat33(
844
- wptype(1.0),
845
- wptype(2.0),
846
- wptype(3.0),
847
- -wptype(4.0),
848
- wptype(5.0),
849
- wptype(6.0),
850
- wptype(7.0),
851
- wptype(8.0),
852
- wptype(9.0),
853
- ),
854
- )
855
-
856
- wp.expect_eq(
857
- mat44(
858
- wptype(1.0),
859
- wptype(2.0),
860
- wptype(3.0),
861
- wptype(4.0),
862
- wptype(5.0),
863
- wptype(6.0),
864
- wptype(7.0),
865
- wptype(8.0),
866
- wptype(9.0),
867
- wptype(10.0),
868
- wptype(11.0),
869
- wptype(12.0),
870
- wptype(13.0),
871
- wptype(14.0),
872
- wptype(15.0),
873
- wptype(16.0),
874
- ),
875
- mat44(
876
- wptype(1.0),
877
- wptype(2.0),
878
- wptype(3.0),
879
- wptype(4.0),
880
- wptype(5.0),
881
- wptype(6.0),
882
- wptype(7.0),
883
- wptype(8.0),
884
- wptype(9.0),
885
- wptype(10.0),
886
- wptype(11.0),
887
- wptype(12.0),
888
- wptype(13.0),
889
- wptype(14.0),
890
- wptype(15.0),
891
- wptype(16.0),
892
- ),
893
- )
894
-
895
- wp.expect_neq(
896
- mat44(
897
- wptype(1.0),
898
- wptype(2.0),
899
- wptype(3.0),
900
- wptype(4.0),
901
- wptype(5.0),
902
- wptype(6.0),
903
- wptype(7.0),
904
- wptype(8.0),
905
- wptype(9.0),
906
- wptype(10.0),
907
- wptype(11.0),
908
- wptype(12.0),
909
- wptype(13.0),
910
- wptype(14.0),
911
- wptype(15.0),
912
- wptype(16.0),
913
- ),
914
- mat44(
915
- -wptype(1.0),
916
- wptype(2.0),
917
- wptype(3.0),
918
- wptype(4.0),
919
- wptype(5.0),
920
- wptype(6.0),
921
- wptype(7.0),
922
- wptype(8.0),
923
- wptype(9.0),
924
- wptype(10.0),
925
- wptype(11.0),
926
- wptype(12.0),
927
- wptype(13.0),
928
- wptype(14.0),
929
- wptype(15.0),
930
- wptype(16.0),
931
- ),
932
- )
933
-
934
- wp.expect_eq(
935
- mat55(
936
- wptype(1.0),
937
- wptype(2.0),
938
- wptype(3.0),
939
- wptype(4.0),
940
- wptype(5.0),
941
- wptype(6.0),
942
- wptype(7.0),
943
- wptype(8.0),
944
- wptype(9.0),
945
- wptype(10.0),
946
- wptype(11.0),
947
- wptype(12.0),
948
- wptype(13.0),
949
- wptype(14.0),
950
- wptype(15.0),
951
- wptype(16.0),
952
- wptype(17.0),
953
- wptype(18.0),
954
- wptype(19.0),
955
- wptype(20.0),
956
- wptype(21.0),
957
- wptype(22.0),
958
- wptype(23.0),
959
- wptype(24.0),
960
- wptype(25.0),
961
- ),
962
- mat55(
963
- wptype(1.0),
964
- wptype(2.0),
965
- wptype(3.0),
966
- wptype(4.0),
967
- wptype(5.0),
968
- wptype(6.0),
969
- wptype(7.0),
970
- wptype(8.0),
971
- wptype(9.0),
972
- wptype(10.0),
973
- wptype(11.0),
974
- wptype(12.0),
975
- wptype(13.0),
976
- wptype(14.0),
977
- wptype(15.0),
978
- wptype(16.0),
979
- wptype(17.0),
980
- wptype(18.0),
981
- wptype(19.0),
982
- wptype(20.0),
983
- wptype(21.0),
984
- wptype(22.0),
985
- wptype(23.0),
986
- wptype(24.0),
987
- wptype(25.0),
988
- ),
989
- )
990
-
991
- wp.expect_neq(
992
- mat55(
993
- wptype(1.0),
994
- wptype(2.0),
995
- wptype(3.0),
996
- wptype(4.0),
997
- wptype(5.0),
998
- wptype(6.0),
999
- wptype(7.0),
1000
- wptype(8.0),
1001
- wptype(9.0),
1002
- wptype(10.0),
1003
- wptype(11.0),
1004
- wptype(12.0),
1005
- wptype(13.0),
1006
- wptype(14.0),
1007
- wptype(15.0),
1008
- wptype(16.0),
1009
- wptype(17.0),
1010
- wptype(18.0),
1011
- wptype(19.0),
1012
- wptype(20.0),
1013
- wptype(21.0),
1014
- wptype(22.0),
1015
- wptype(23.0),
1016
- wptype(24.0),
1017
- wptype(25.0),
1018
- ),
1019
- mat55(
1020
- wptype(1.0),
1021
- wptype(2.0),
1022
- wptype(3.0),
1023
- wptype(4.0),
1024
- wptype(5.0),
1025
- wptype(6.0),
1026
- wptype(7.0),
1027
- wptype(8.0),
1028
- wptype(9.0),
1029
- wptype(10.0),
1030
- wptype(11.0),
1031
- wptype(12.0),
1032
- wptype(13.0),
1033
- wptype(14.0),
1034
- wptype(15.0),
1035
- wptype(16.0),
1036
- -wptype(17.0),
1037
- wptype(18.0),
1038
- wptype(19.0),
1039
- wptype(20.0),
1040
- wptype(21.0),
1041
- wptype(22.0),
1042
- wptype(23.0),
1043
- wptype(24.0),
1044
- wptype(25.0),
1045
- ),
1046
- )
1047
-
1048
- kernel = getkernel(check_mat_equality, suffix=dtype.__name__)
1049
-
1050
- if register_kernels:
1051
- return
1052
-
1053
- wp.launch(kernel, dim=1, inputs=[], outputs=[], device=device)
1054
-
1055
-
1056
- def test_negation(test, device, dtype, register_kernels=False):
460
+ def test_subtraction(test, device, dtype, register_kernels=False):
1057
461
  rng = np.random.default_rng(123)
1058
462
 
1059
463
  tol = {
1060
- np.float16: 1.0e-2,
464
+ np.float16: 5.0e-3,
1061
465
  np.float32: 1.0e-6,
1062
466
  np.float64: 1.0e-8,
1063
467
  }.get(dtype, 0)
@@ -1070,1584 +474,117 @@ def test_negation(test, device, dtype, register_kernels=False):
1070
474
 
1071
475
  output_select_kernel = get_select_kernel(wptype)
1072
476
 
1073
- def check_mat_negation(
1074
- m2: wp.array(dtype=mat22),
1075
- m3: wp.array(dtype=mat33),
1076
- m4: wp.array(dtype=mat44),
1077
- m5: wp.array(dtype=mat55),
477
+ def check_mat_sub(
478
+ s2: wp.array(dtype=mat22),
479
+ s3: wp.array(dtype=mat33),
480
+ s4: wp.array(dtype=mat44),
481
+ s5: wp.array(dtype=mat55),
482
+ v2: wp.array(dtype=mat22),
483
+ v3: wp.array(dtype=mat33),
484
+ v4: wp.array(dtype=mat44),
485
+ v5: wp.array(dtype=mat55),
1078
486
  outcomponents: wp.array(dtype=wptype),
1079
487
  ):
1080
- mat2 = -m2[0]
1081
- mat3 = -m3[0]
1082
- mat4 = -m4[0]
1083
- mat5 = -m5[0]
488
+ v2result = v2[0] - s2[0]
489
+ v3result = v3[0] - s3[0]
490
+ v4result = v4[0] - s4[0]
491
+ v5result = v5[0] - s5[0]
1084
492
 
1085
493
  # multiply outputs by 2 so we've got something to backpropagate:
1086
494
  idx = 0
1087
495
  for i in range(2):
1088
496
  for j in range(2):
1089
- outcomponents[idx] = wptype(2) * mat2[i, j]
497
+ outcomponents[idx] = wptype(2) * v2result[i, j]
1090
498
  idx = idx + 1
1091
499
 
1092
500
  for i in range(3):
1093
501
  for j in range(3):
1094
- outcomponents[idx] = wptype(2) * mat3[i, j]
502
+ outcomponents[idx] = wptype(2) * v3result[i, j]
1095
503
  idx = idx + 1
1096
504
 
1097
505
  for i in range(4):
1098
506
  for j in range(4):
1099
- outcomponents[idx] = wptype(2) * mat4[i, j]
507
+ outcomponents[idx] = wptype(2) * v4result[i, j]
1100
508
  idx = idx + 1
1101
509
 
1102
510
  for i in range(5):
1103
511
  for j in range(5):
1104
- outcomponents[idx] = wptype(2) * mat5[i, j]
512
+ outcomponents[idx] = wptype(2) * v5result[i, j]
1105
513
  idx = idx + 1
1106
514
 
1107
- kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
515
+ kernel = getkernel(check_mat_sub, suffix=dtype.__name__)
1108
516
 
1109
517
  if register_kernels:
1110
518
  return
1111
519
 
1112
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1113
- m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1114
- m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1115
- m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1116
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1117
-
1118
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
1119
-
1120
- assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
1121
- assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
1122
- assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
1123
- assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
1124
-
1125
- if dtype in np_float_types:
1126
- idx = 0
1127
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1128
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
1129
- for i in range(dim):
1130
- for j in range(dim):
1131
- tape = wp.Tape()
1132
- with tape:
1133
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
1134
- wp.launch(
1135
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1136
- )
1137
- tape.backward(loss=out)
1138
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1139
- expectedresult[i, j] = -2
1140
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
1141
- tape.zero()
1142
- idx = idx + 1
1143
-
1144
-
1145
- def test_transpose(test, device, dtype, register_kernels=False):
1146
- rng = np.random.default_rng(123)
1147
-
1148
- tol = {
1149
- np.float16: 1.0e-2,
1150
- np.float32: 1.0e-6,
1151
- np.float64: 1.0e-8,
1152
- }.get(dtype, 0)
1153
-
1154
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1155
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1156
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1157
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1158
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1159
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1160
-
1161
- output_select_kernel = get_select_kernel(wptype)
1162
-
1163
- def check_mat_transpose(
1164
- m2: wp.array(dtype=mat22),
1165
- m3: wp.array(dtype=mat33),
1166
- m4: wp.array(dtype=mat44),
1167
- m5: wp.array(dtype=mat55),
1168
- m32: wp.array(dtype=mat32),
1169
- outcomponents: wp.array(dtype=wptype),
1170
- ):
1171
- # multiply outputs by 2 so we've got something to backpropagate:
1172
- mat2 = wptype(2) * wp.transpose(m2[0])
1173
- mat3 = wptype(2) * wp.transpose(m3[0])
1174
- mat4 = wptype(2) * wp.transpose(m4[0])
1175
- mat5 = wptype(2) * wp.transpose(m5[0])
1176
- mat32 = wptype(2) * wp.transpose(m32[0])
1177
-
1178
- idx = 0
1179
- for i in range(2):
1180
- for j in range(2):
1181
- outcomponents[idx] = mat2[i, j]
1182
- idx = idx + 1
1183
-
1184
- for i in range(3):
1185
- for j in range(3):
1186
- outcomponents[idx] = mat3[i, j]
1187
- idx = idx + 1
1188
-
1189
- for i in range(4):
1190
- for j in range(4):
1191
- outcomponents[idx] = mat4[i, j]
1192
- idx = idx + 1
1193
-
1194
- for i in range(5):
1195
- for j in range(5):
1196
- outcomponents[idx] = mat5[i, j]
1197
- idx = idx + 1
1198
-
1199
- for i in range(2):
1200
- for j in range(3):
1201
- outcomponents[idx] = mat32[i, j]
1202
- idx = idx + 1
1203
-
1204
- kernel = getkernel(check_mat_transpose, suffix=dtype.__name__)
1205
-
1206
- if register_kernels:
1207
- return
1208
-
1209
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1210
- m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1211
- m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1212
- m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1213
- m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1214
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 3, dtype=wptype, requires_grad=True, device=device)
1215
-
1216
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1217
-
1218
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy()[0].T.reshape(-1), tol=tol)
1219
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy()[0].T.reshape(-1), tol=tol)
1220
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy()[0].T.reshape(-1), tol=tol)
1221
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy()[0].T.reshape(-1), tol=tol)
1222
- assert_np_equal(outcomponents.numpy()[54:], 2 * m32.numpy()[0].T.reshape(-1), tol=tol)
1223
-
1224
- if dtype in np_float_types:
1225
- idx = 0
1226
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1227
- for input in [m2, m3, m4, m5]:
1228
- for i in range(input.dtype._shape_[0]):
1229
- for j in range(input.dtype._shape_[1]):
1230
- tape = wp.Tape()
1231
- with tape:
1232
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1233
- wp.launch(
1234
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1235
- )
1236
- tape.backward(loss=out)
1237
- expectedresult = np.zeros((input.dtype._shape_[1], input.dtype._shape_[0]), dtype=dtype)
1238
- expectedresult[j, i] = 2
1239
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
1240
- tape.zero()
1241
- idx = idx + 1
1242
-
1243
-
1244
- def test_scalar_multiplication(test, device, dtype, register_kernels=False):
1245
- rng = np.random.default_rng(123)
1246
-
1247
- tol = {
1248
- np.float16: 1.0e-2,
1249
- np.float32: 1.0e-6,
1250
- np.float64: 1.0e-8,
1251
- }.get(dtype, 0)
1252
-
1253
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1254
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1255
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1256
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1257
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1258
-
1259
- output_select_kernel = get_select_kernel(wptype)
1260
-
1261
- def check_mat_scalar_mul(
1262
- s: wp.array(dtype=wptype),
1263
- m2: wp.array(dtype=mat22),
1264
- m3: wp.array(dtype=mat33),
1265
- m4: wp.array(dtype=mat44),
1266
- m5: wp.array(dtype=mat55),
1267
- outcomponents: wp.array(dtype=wptype),
1268
- outcomponents_rightmul: wp.array(dtype=wptype),
1269
- ):
1270
- m2result = s[0] * m2[0]
1271
- m3result = s[0] * m3[0]
1272
- m4result = s[0] * m4[0]
1273
- m5result = s[0] * m5[0]
1274
-
1275
- m2resultright = m2[0] * s[0]
1276
- m3resultright = m3[0] * s[0]
1277
- m4resultright = m4[0] * s[0]
1278
- m5resultright = m5[0] * s[0]
1279
-
1280
- m2result_2 = s[0] * m2[0]
1281
- m3result_2 = s[0] * m3[0]
1282
- m4result_2 = s[0] * m4[0]
1283
- m5result_2 = s[0] * m5[0]
1284
-
1285
- m2resultright_2 = m2[0] * s[0]
1286
- m3resultright_2 = m3[0] * s[0]
1287
- m4resultright_2 = m4[0] * s[0]
1288
- m5resultright_2 = m5[0] * s[0]
1289
-
1290
- # multiply outputs by 2 so we've got something to backpropagate:
1291
- idx = 0
1292
- for i in range(2):
1293
- for j in range(2):
1294
- outcomponents[idx] = wptype(2) * m2result[i, j]
1295
- outcomponents_rightmul[idx] = wptype(2) * m2resultright[i, j]
1296
- idx = idx + 1
1297
-
1298
- for i in range(3):
1299
- for j in range(3):
1300
- outcomponents[idx] = wptype(2) * m3result[i, j]
1301
- outcomponents_rightmul[idx] = wptype(2) * m3resultright[i, j]
1302
- idx = idx + 1
1303
-
1304
- for i in range(4):
1305
- for j in range(4):
1306
- outcomponents[idx] = wptype(2) * m4result[i, j]
1307
- outcomponents_rightmul[idx] = wptype(2) * m4resultright[i, j]
1308
- idx = idx + 1
1309
-
1310
- for i in range(5):
1311
- for j in range(5):
1312
- outcomponents[idx] = wptype(2) * m5result[i, j]
1313
- outcomponents_rightmul[idx] = wptype(2) * m5resultright[i, j]
1314
- idx = idx + 1
1315
-
1316
- for i in range(2):
1317
- for j in range(2):
1318
- outcomponents[idx] = wptype(2) * m2result_2[i, j]
1319
- outcomponents_rightmul[idx] = wptype(2) * m2resultright_2[i, j]
1320
- idx = idx + 1
1321
-
1322
- for i in range(3):
1323
- for j in range(3):
1324
- outcomponents[idx] = wptype(2) * m3result_2[i, j]
1325
- outcomponents_rightmul[idx] = wptype(2) * m3resultright_2[i, j]
1326
- idx = idx + 1
1327
-
1328
- for i in range(4):
1329
- for j in range(4):
1330
- outcomponents[idx] = wptype(2) * m4result_2[i, j]
1331
- outcomponents_rightmul[idx] = wptype(2) * m4resultright_2[i, j]
1332
- idx = idx + 1
1333
-
1334
- for i in range(5):
1335
- for j in range(5):
1336
- outcomponents[idx] = wptype(2) * m5result_2[i, j]
1337
- outcomponents_rightmul[idx] = wptype(2) * m5resultright_2[i, j]
1338
- idx = idx + 1
1339
-
1340
- kernel = getkernel(check_mat_scalar_mul, suffix=dtype.__name__)
1341
-
1342
- if register_kernels:
1343
- return
1344
-
1345
- s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
1346
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1347
- m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1348
- m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1349
- m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1350
- outcomponents = wp.zeros(2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5), dtype=wptype, requires_grad=True, device=device)
1351
- outcomponents_rightmul = wp.zeros(
1352
- 2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5), dtype=wptype, requires_grad=True, device=device
1353
- )
1354
-
1355
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents, outcomponents_rightmul], device=device)
1356
-
1357
- sval = s.numpy()[0]
1358
- assert_np_equal(outcomponents.numpy()[:4], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1359
- assert_np_equal(outcomponents.numpy()[4:13], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1360
- assert_np_equal(outcomponents.numpy()[13:29], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1361
- assert_np_equal(outcomponents.numpy()[29:54], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1362
-
1363
- assert_np_equal(outcomponents_rightmul.numpy()[:4], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1364
- assert_np_equal(outcomponents_rightmul.numpy()[4:13], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1365
- assert_np_equal(outcomponents_rightmul.numpy()[13:29], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1366
- assert_np_equal(outcomponents_rightmul.numpy()[29:54], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1367
-
1368
- assert_np_equal(outcomponents.numpy()[54:58], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1369
- assert_np_equal(outcomponents.numpy()[58:67], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1370
- assert_np_equal(outcomponents.numpy()[67:83], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1371
- assert_np_equal(outcomponents.numpy()[83:108], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1372
-
1373
- assert_np_equal(outcomponents_rightmul.numpy()[54:58], 2 * sval * m2.numpy().reshape(-1), tol=tol)
1374
- assert_np_equal(outcomponents_rightmul.numpy()[58:67], 2 * sval * m3.numpy().reshape(-1), tol=10 * tol)
1375
- assert_np_equal(outcomponents_rightmul.numpy()[67:83], 2 * sval * m4.numpy().reshape(-1), tol=10 * tol)
1376
- assert_np_equal(outcomponents_rightmul.numpy()[83:108], 2 * sval * m5.numpy().reshape(-1), tol=10 * tol)
1377
-
1378
- if dtype in np_float_types:
1379
- idx = 0
1380
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1381
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
1382
- for i in range(dim):
1383
- for j in range(dim):
1384
- # test left mul gradient:
1385
- tape = wp.Tape()
1386
- with tape:
1387
- wp.launch(
1388
- kernel,
1389
- dim=1,
1390
- inputs=[s, m2, m3, m4, m5],
1391
- outputs=[outcomponents, outcomponents_rightmul],
1392
- device=device,
1393
- )
1394
- wp.launch(
1395
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1396
- )
1397
- tape.backward(loss=out)
1398
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1399
- expectedresult[i, j] = 2 * sval
1400
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
1401
- assert_np_equal(tape.gradients[s].numpy()[0], 2 * input.numpy()[0, i, j], tol=10 * tol)
1402
- tape.zero()
1403
-
1404
- # test right mul gradient:
1405
- tape = wp.Tape()
1406
- with tape:
1407
- wp.launch(
1408
- kernel,
1409
- dim=1,
1410
- inputs=[s, m2, m3, m4, m5],
1411
- outputs=[outcomponents, outcomponents_rightmul],
1412
- device=device,
1413
- )
1414
- wp.launch(
1415
- output_select_kernel,
1416
- dim=1,
1417
- inputs=[outcomponents_rightmul, idx],
1418
- outputs=[out],
1419
- device=device,
1420
- )
1421
- tape.backward(loss=out)
1422
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1423
- expectedresult[i, j] = 2 * sval
1424
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
1425
- assert_np_equal(tape.gradients[s].numpy()[0], 2 * input.numpy()[0, i, j], tol=10 * tol)
1426
- tape.zero()
1427
-
1428
- idx = idx + 1
1429
-
1430
-
1431
- def test_matvec_multiplication(test, device, dtype, register_kernels=False):
1432
- rng = np.random.default_rng(123)
1433
-
1434
- tol = {
1435
- np.float16: 2.0e-2,
1436
- np.float32: 5.0e-6,
1437
- np.float64: 1.0e-8,
1438
- }.get(dtype, 0)
1439
-
1440
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1441
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1442
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1443
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1444
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1445
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1446
-
1447
- vec2 = wp.types.vector(length=2, dtype=wptype)
1448
- vec3 = wp.types.vector(length=3, dtype=wptype)
1449
- vec4 = wp.types.vector(length=4, dtype=wptype)
1450
- vec5 = wp.types.vector(length=5, dtype=wptype)
1451
-
1452
- output_select_kernel = get_select_kernel(wptype)
1453
-
1454
- def check_mat_vec_mul(
1455
- v2: wp.array(dtype=vec2),
1456
- v3: wp.array(dtype=vec3),
1457
- v4: wp.array(dtype=vec4),
1458
- v5: wp.array(dtype=vec5),
1459
- v32: wp.array(dtype=vec2),
1460
- m2: wp.array(dtype=mat22),
1461
- m3: wp.array(dtype=mat33),
1462
- m4: wp.array(dtype=mat44),
1463
- m5: wp.array(dtype=mat55),
1464
- m32: wp.array(dtype=mat32),
1465
- outcomponents: wp.array(dtype=wptype),
1466
- ):
1467
- v2result = m2[0] * v2[0]
1468
- v3result = m3[0] * v3[0]
1469
- v4result = m4[0] * v4[0]
1470
- v5result = m5[0] * v5[0]
1471
- v32result = m32[0] * v32[0]
1472
- v2result_2 = m2[0] @ v2[0]
1473
- v3result_2 = m3[0] @ v3[0]
1474
- v4result_2 = m4[0] @ v4[0]
1475
- v5result_2 = m5[0] @ v5[0]
1476
- v32result_2 = m32[0] @ v32[0]
1477
-
1478
- idx = 0
1479
-
1480
- # multiply outputs by 2 so we've got something to backpropagate:
1481
- for i in range(2):
1482
- outcomponents[idx] = wptype(2) * v2result[i]
1483
- idx = idx + 1
1484
-
1485
- for i in range(3):
1486
- outcomponents[idx] = wptype(2) * v3result[i]
1487
- idx = idx + 1
1488
-
1489
- for i in range(4):
1490
- outcomponents[idx] = wptype(2) * v4result[i]
1491
- idx = idx + 1
1492
-
1493
- for i in range(5):
1494
- outcomponents[idx] = wptype(2) * v5result[i]
1495
- idx = idx + 1
1496
-
1497
- for i in range(3):
1498
- outcomponents[idx] = wptype(2) * v32result[i]
1499
- idx = idx + 1
1500
-
1501
- for i in range(2):
1502
- outcomponents[idx] = wptype(2) * v2result_2[i]
1503
- idx = idx + 1
1504
-
1505
- for i in range(3):
1506
- outcomponents[idx] = wptype(2) * v3result_2[i]
1507
- idx = idx + 1
1508
-
1509
- for i in range(4):
1510
- outcomponents[idx] = wptype(2) * v4result_2[i]
1511
- idx = idx + 1
1512
-
1513
- for i in range(5):
1514
- outcomponents[idx] = wptype(2) * v5result_2[i]
1515
- idx = idx + 1
1516
-
1517
- for i in range(3):
1518
- outcomponents[idx] = wptype(2) * v32result_2[i]
1519
- idx = idx + 1
1520
-
1521
- kernel = getkernel(check_mat_vec_mul, suffix=dtype.__name__)
1522
-
1523
- if register_kernels:
1524
- return
1525
-
1526
- v2 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1527
- v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1528
- v4 = wp.array(randvals(rng, [1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
1529
- v5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
1530
- v32 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
1531
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1532
- m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1533
- m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1534
- m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1535
- m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1536
- outcomponents = wp.zeros(2 * (2 + 3 + 4 + 5 + 3), dtype=wptype, requires_grad=True, device=device)
1537
-
1538
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1539
-
1540
- assert_np_equal(outcomponents.numpy()[:2], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1541
- assert_np_equal(outcomponents.numpy()[2:5], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1542
- assert_np_equal(outcomponents.numpy()[5:9], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=5 * tol)
1543
- assert_np_equal(outcomponents.numpy()[9:14], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=5 * tol)
1544
- assert_np_equal(outcomponents.numpy()[14:17], 2 * np.matmul(m32.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1545
- assert_np_equal(outcomponents.numpy()[17:19], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1546
- assert_np_equal(outcomponents.numpy()[19:22], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1547
- assert_np_equal(outcomponents.numpy()[22:26], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=5 * tol)
1548
- assert_np_equal(outcomponents.numpy()[26:31], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=5 * tol)
1549
- assert_np_equal(outcomponents.numpy()[31:34], 2 * np.matmul(m32.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1550
-
1551
- if dtype in np_float_types:
1552
- idx = 0
1553
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1554
- for dim, invec, inmat in [(2, v2, m2), (3, v3, m3), (4, v4, m4), (5, v5, m5), (3, v32, m32)]:
1555
- for i in range(dim):
1556
- tape = wp.Tape()
1557
- with tape:
1558
- wp.launch(
1559
- kernel,
1560
- dim=1,
1561
- inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
1562
- outputs=[outcomponents],
1563
- device=device,
1564
- )
1565
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1566
- tape.backward(loss=out)
1567
-
1568
- assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, i, :], tol=2 * tol)
1569
- expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
1570
- expectedresult[i, :] = 2 * invec.numpy()[0]
1571
- assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)
1572
-
1573
- tape.zero()
1574
-
1575
- idx = idx + 1
1576
-
1577
-
1578
- def test_matmat_multiplication(test, device, dtype, register_kernels=False):
1579
- rng = np.random.default_rng(123)
1580
-
1581
- tol = {
1582
- np.float16: 2.0e-2,
1583
- np.float32: 5.0e-6,
1584
- np.float64: 1.0e-8,
1585
- }.get(dtype, 0)
1586
-
1587
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1588
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1589
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1590
- mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
1591
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1592
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1593
-
1594
- output_select_kernel = get_select_kernel(wptype)
1595
-
1596
- def check_mat_mat_mul(
1597
- a2: wp.array(dtype=mat22),
1598
- a3: wp.array(dtype=mat33),
1599
- a4: wp.array(dtype=mat44),
1600
- a5: wp.array(dtype=mat55),
1601
- a32: wp.array(dtype=mat32),
1602
- b2: wp.array(dtype=mat22),
1603
- b3: wp.array(dtype=mat33),
1604
- b4: wp.array(dtype=mat44),
1605
- b5: wp.array(dtype=mat55),
1606
- b32: wp.array(dtype=mat32),
1607
- outcomponents: wp.array(dtype=wptype),
1608
- ):
1609
- c2result = b2[0] * a2[0]
1610
- c3result = b3[0] * a3[0]
1611
- c4result = b4[0] * a4[0]
1612
- c5result = b5[0] * a5[0]
1613
- c32result = b32[0] * a2[0]
1614
- c32result2 = b3[0] * a32[0]
1615
- c2result_2 = b2[0] @ a2[0]
1616
- c3result_2 = b3[0] @ a3[0]
1617
- c4result_2 = b4[0] @ a4[0]
1618
- c5result_2 = b5[0] @ a5[0]
1619
- c32result_2 = b32[0] @ a2[0]
1620
- c32result2_2 = b3[0] @ a32[0]
1621
-
1622
- # multiply outputs by 2 so we've got something to backpropagate:
1623
- idx = 0
1624
- for i in range(2):
1625
- for j in range(2):
1626
- outcomponents[idx] = wptype(2) * c2result[i, j]
1627
- idx = idx + 1
1628
-
1629
- for i in range(3):
1630
- for j in range(3):
1631
- outcomponents[idx] = wptype(2) * c3result[i, j]
1632
- idx = idx + 1
1633
-
1634
- for i in range(4):
1635
- for j in range(4):
1636
- outcomponents[idx] = wptype(2) * c4result[i, j]
1637
- idx = idx + 1
1638
-
1639
- for i in range(5):
1640
- for j in range(5):
1641
- outcomponents[idx] = wptype(2) * c5result[i, j]
1642
- idx = idx + 1
1643
-
1644
- for i in range(3):
1645
- for j in range(2):
1646
- outcomponents[idx] = wptype(2) * c32result[i, j]
1647
- idx = idx + 1
1648
-
1649
- for i in range(3):
1650
- for j in range(2):
1651
- outcomponents[idx] = wptype(2) * c32result2[i, j]
1652
- idx = idx + 1
1653
-
1654
- for i in range(2):
1655
- for j in range(2):
1656
- outcomponents[idx] = wptype(2) * c2result_2[i, j]
1657
- idx = idx + 1
1658
-
1659
- for i in range(3):
1660
- for j in range(3):
1661
- outcomponents[idx] = wptype(2) * c3result_2[i, j]
1662
- idx = idx + 1
1663
-
1664
- for i in range(4):
1665
- for j in range(4):
1666
- outcomponents[idx] = wptype(2) * c4result_2[i, j]
1667
- idx = idx + 1
1668
-
1669
- for i in range(5):
1670
- for j in range(5):
1671
- outcomponents[idx] = wptype(2) * c5result_2[i, j]
1672
- idx = idx + 1
1673
-
1674
- for i in range(3):
1675
- for j in range(2):
1676
- outcomponents[idx] = wptype(2) * c32result_2[i, j]
1677
- idx = idx + 1
1678
-
1679
- for i in range(3):
1680
- for j in range(2):
1681
- outcomponents[idx] = wptype(2) * c32result2_2[i, j]
1682
- idx = idx + 1
1683
-
1684
- kernel = getkernel(check_mat_mat_mul, suffix=dtype.__name__)
1685
-
1686
- if register_kernels:
1687
- return
1688
-
1689
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1690
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1691
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1692
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1693
- v32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1694
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1695
- m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1696
- m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1697
- m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1698
- m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
1699
- outcomponents = wp.zeros(
1700
- 2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2 + 3 * 2), dtype=wptype, requires_grad=True, device=device
1701
- )
1702
-
1703
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
1704
-
1705
- assert_np_equal(outcomponents.numpy()[:4], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1706
- assert_np_equal(outcomponents.numpy()[4:13], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1707
- assert_np_equal(outcomponents.numpy()[13:29], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=2 * tol)
1708
- assert_np_equal(outcomponents.numpy()[29:54], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=10 * tol)
1709
- assert_np_equal(outcomponents.numpy()[54:60], 2 * np.matmul(m32.numpy()[0], v2.numpy()[0]), tol=5 * tol)
1710
- assert_np_equal(outcomponents.numpy()[60:66], 2 * np.matmul(m3.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1711
- assert_np_equal(outcomponents.numpy()[66:70], 2 * np.matmul(m2.numpy()[0], v2.numpy()[0]), tol=tol)
1712
- assert_np_equal(outcomponents.numpy()[70:79], 2 * np.matmul(m3.numpy()[0], v3.numpy()[0]), tol=tol)
1713
- assert_np_equal(outcomponents.numpy()[79:95], 2 * np.matmul(m4.numpy()[0], v4.numpy()[0]), tol=2 * tol)
1714
- assert_np_equal(outcomponents.numpy()[95:120], 2 * np.matmul(m5.numpy()[0], v5.numpy()[0]), tol=10 * tol)
1715
- assert_np_equal(outcomponents.numpy()[120:126], 2 * np.matmul(m32.numpy()[0], v2.numpy()[0]), tol=5 * tol)
1716
- assert_np_equal(outcomponents.numpy()[126:132], 2 * np.matmul(m3.numpy()[0], v32.numpy()[0]), tol=5 * tol)
1717
-
1718
- if dtype in np_float_types:
1719
- idx = 0
1720
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1721
- for v, m in [(v2, m2), (v3, m3), (v4, m4), (v5, m5), (v2, m32), (v32, m3)]:
1722
- rows, cols = m.dtype._shape_[0], v.dtype._shape_[1]
1723
- for i in range(rows):
1724
- for j in range(cols):
1725
- tape = wp.Tape()
1726
- with tape:
1727
- wp.launch(
1728
- kernel,
1729
- dim=1,
1730
- inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
1731
- outputs=[outcomponents],
1732
- device=device,
1733
- )
1734
- wp.launch(
1735
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1736
- )
1737
- tape.backward(loss=out)
1738
-
1739
- expected = np.zeros(v.dtype._shape_, dtype=dtype)
1740
- expected[:, j] = 2 * m.numpy()[0, i, :]
1741
- assert_np_equal(tape.gradients[v].numpy()[0], expected, tol=10 * tol)
1742
-
1743
- expected = np.zeros(m.dtype._shape_, dtype=dtype)
1744
- expected[i, :] = 2 * v.numpy()[0, :, j]
1745
- assert_np_equal(tape.gradients[m].numpy()[0], expected, tol=10 * tol)
1746
-
1747
- tape.zero()
1748
- idx = idx + 1
1749
-
1750
-
1751
- def test_cw_multiplication(test, device, dtype, register_kernels=False):
1752
- rng = np.random.default_rng(123)
1753
-
1754
- tol = {
1755
- np.float16: 5.0e-2,
1756
- np.float32: 1.0e-6,
1757
- np.float64: 1.0e-8,
1758
- }.get(dtype, 0)
1759
-
1760
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1761
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1762
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1763
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1764
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1765
-
1766
- output_select_kernel = get_select_kernel(wptype)
1767
-
1768
- def check_mat_cw_mul(
1769
- s2: wp.array(dtype=mat22),
1770
- s3: wp.array(dtype=mat33),
1771
- s4: wp.array(dtype=mat44),
1772
- s5: wp.array(dtype=mat55),
1773
- v2: wp.array(dtype=mat22),
1774
- v3: wp.array(dtype=mat33),
1775
- v4: wp.array(dtype=mat44),
1776
- v5: wp.array(dtype=mat55),
1777
- outcomponents: wp.array(dtype=wptype),
1778
- ):
1779
- v2result = wptype(2) * wp.cw_mul(v2[0], s2[0])
1780
- v3result = wptype(2) * wp.cw_mul(v3[0], s3[0])
1781
- v4result = wptype(2) * wp.cw_mul(v4[0], s4[0])
1782
- v5result = wptype(2) * wp.cw_mul(v5[0], s5[0])
1783
-
1784
- # multiply outputs by 2 so we've got something to backpropagate:
1785
- idx = 0
1786
- for i in range(2):
1787
- for j in range(2):
1788
- outcomponents[idx] = v2result[i, j]
1789
- idx = idx + 1
1790
-
1791
- for i in range(3):
1792
- for j in range(3):
1793
- outcomponents[idx] = v3result[i, j]
1794
- idx = idx + 1
1795
-
1796
- for i in range(4):
1797
- for j in range(4):
1798
- outcomponents[idx] = v4result[i, j]
1799
- idx = idx + 1
1800
-
1801
- for i in range(5):
1802
- for j in range(5):
1803
- outcomponents[idx] = v5result[i, j]
1804
- idx = idx + 1
1805
-
1806
- kernel = getkernel(check_mat_cw_mul, suffix=dtype.__name__)
1807
-
1808
- if register_kernels:
1809
- return
1810
-
1811
- s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1812
- s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1813
- s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1814
- s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1815
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1816
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1817
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1818
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1819
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1820
-
1821
- wp.launch(
1822
- kernel,
1823
- dim=1,
1824
- inputs=[
1825
- s2,
1826
- s3,
1827
- s4,
1828
- s5,
1829
- v2,
1830
- v3,
1831
- v4,
1832
- v5,
1833
- ],
1834
- outputs=[outcomponents],
1835
- device=device,
1836
- )
1837
-
1838
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() * s2.numpy()).reshape(-1), tol=50 * tol)
1839
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() * s3.numpy()).reshape(-1), tol=50 * tol)
1840
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() * s4.numpy()).reshape(-1), tol=50 * tol)
1841
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() * s5.numpy()).reshape(-1), tol=50 * tol)
1842
-
1843
- if dtype in np_float_types:
1844
- idx = 0
1845
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1846
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
1847
- for i in range(dim):
1848
- for j in range(dim):
1849
- tape = wp.Tape()
1850
- with tape:
1851
- wp.launch(
1852
- kernel,
1853
- dim=1,
1854
- inputs=[
1855
- s2,
1856
- s3,
1857
- s4,
1858
- s5,
1859
- v2,
1860
- v3,
1861
- v4,
1862
- v5,
1863
- ],
1864
- outputs=[outcomponents],
1865
- device=device,
1866
- )
1867
- wp.launch(
1868
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
1869
- )
1870
- tape.backward(loss=out)
1871
- expectedresult = np.zeros((dim, dim), dtype=dtype)
1872
- expectedresult[i, j] = 2 * in1.numpy()[0][i, j]
1873
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=5 * tol)
1874
- expectedresult[i, j] = 2 * in2.numpy()[0][i, j]
1875
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=5 * tol)
1876
- tape.zero()
1877
-
1878
- idx = idx + 1
1879
-
1880
-
1881
- def test_cw_division(test, device, dtype, register_kernels=False):
1882
- rng = np.random.default_rng(123)
1883
-
1884
- tol = {
1885
- np.float16: 1.0e-2,
1886
- np.float32: 1.0e-6,
1887
- np.float64: 1.0e-8,
1888
- }.get(dtype, 0)
1889
-
1890
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1891
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1892
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1893
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1894
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
1895
-
1896
- output_select_kernel = get_select_kernel(wptype)
1897
-
1898
- def check_mat_cw_div(
1899
- s2: wp.array(dtype=mat22),
1900
- s3: wp.array(dtype=mat33),
1901
- s4: wp.array(dtype=mat44),
1902
- s5: wp.array(dtype=mat55),
1903
- v2: wp.array(dtype=mat22),
1904
- v3: wp.array(dtype=mat33),
1905
- v4: wp.array(dtype=mat44),
1906
- v5: wp.array(dtype=mat55),
1907
- outcomponents: wp.array(dtype=wptype),
1908
- ):
1909
- v2result = wptype(2) * wp.cw_div(v2[0], s2[0])
1910
- v3result = wptype(2) * wp.cw_div(v3[0], s3[0])
1911
- v4result = wptype(2) * wp.cw_div(v4[0], s4[0])
1912
- v5result = wptype(2) * wp.cw_div(v5[0], s5[0])
1913
-
1914
- # multiply outputs by 2 so we've got something to backpropagate:
1915
- idx = 0
1916
- for i in range(2):
1917
- for j in range(2):
1918
- outcomponents[idx] = v2result[i, j]
1919
- idx = idx + 1
1920
-
1921
- for i in range(3):
1922
- for j in range(3):
1923
- outcomponents[idx] = v3result[i, j]
1924
- idx = idx + 1
1925
-
1926
- for i in range(4):
1927
- for j in range(4):
1928
- outcomponents[idx] = v4result[i, j]
1929
- idx = idx + 1
1930
-
1931
- for i in range(5):
1932
- for j in range(5):
1933
- outcomponents[idx] = v5result[i, j]
1934
- idx = idx + 1
1935
-
1936
- kernel = getkernel(check_mat_cw_div, suffix=dtype.__name__)
1937
-
1938
- if register_kernels:
1939
- return
1940
-
1941
- s2 = randvals(rng, [1, 2, 2], dtype)
1942
- s3 = randvals(rng, [1, 3, 3], dtype)
1943
- s4 = randvals(rng, [1, 4, 4], dtype)
1944
- s5 = randvals(rng, [1, 5, 5], dtype)
1945
-
1946
- # set denominators to 1 if their magnitudes are small
1947
- # to prevent divide by zero, or overflows if we're testing
1948
- # float16:
1949
- s2[np.abs(s2) < 1.0e-2] = 1
1950
- s3[np.abs(s3) < 1.0e-2] = 1
1951
- s4[np.abs(s4) < 1.0e-2] = 1
1952
- s5[np.abs(s5) < 1.0e-2] = 1
1953
-
1954
- s2 = wp.array(s2, dtype=mat22, requires_grad=True, device=device)
1955
- s3 = wp.array(s3, dtype=mat33, requires_grad=True, device=device)
1956
- s4 = wp.array(s4, dtype=mat44, requires_grad=True, device=device)
1957
- s5 = wp.array(s5, dtype=mat55, requires_grad=True, device=device)
1958
-
1959
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
1960
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
1961
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1962
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
1963
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
1964
-
1965
- wp.launch(
1966
- kernel,
1967
- dim=1,
1968
- inputs=[
1969
- s2,
1970
- s3,
1971
- s4,
1972
- s5,
1973
- v2,
1974
- v3,
1975
- v4,
1976
- v5,
1977
- ],
1978
- outputs=[outcomponents],
1979
- device=device,
1980
- )
1981
-
1982
- if dtype in np_float_types:
1983
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() / s2.numpy()).reshape(-1), tol=50 * tol)
1984
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() / s3.numpy()).reshape(-1), tol=50 * tol)
1985
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() / s4.numpy()).reshape(-1), tol=50 * tol)
1986
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() / s5.numpy()).reshape(-1), tol=50 * tol)
1987
- else:
1988
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() // s2.numpy()).reshape(-1), tol=50 * tol)
1989
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() // s3.numpy()).reshape(-1), tol=50 * tol)
1990
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() // s4.numpy()).reshape(-1), tol=50 * tol)
1991
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() // s5.numpy()).reshape(-1), tol=50 * tol)
1992
-
1993
- if dtype in np_float_types:
1994
- idx = 0
1995
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1996
- for dim, s, v in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
1997
- for i in range(dim):
1998
- for j in range(dim):
1999
- tape = wp.Tape()
2000
- with tape:
2001
- wp.launch(
2002
- kernel,
2003
- dim=1,
2004
- inputs=[
2005
- s2,
2006
- s3,
2007
- s4,
2008
- s5,
2009
- v2,
2010
- v3,
2011
- v4,
2012
- v5,
2013
- ],
2014
- outputs=[outcomponents],
2015
- device=device,
2016
- )
2017
- wp.launch(
2018
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2019
- )
2020
- tape.backward(loss=out)
2021
-
2022
- # y = v/s
2023
- # dy/dv = 1.0/s
2024
- # dy/ds = -v/s^2
2025
-
2026
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2027
- expectedresult[i, j] = 2.0 / (s.numpy()[0, i, j])
2028
- assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=50 * tol)
2029
- expectedresult[i, j] = -2.0 * v.numpy()[0, i, j] / (s.numpy()[0, i, j] ** 2)
2030
- assert_np_equal(
2031
- tape.gradients[s].numpy()[0], expectedresult, tol=abs(outcomponents.numpy()[idx]) * 50 * tol
2032
- )
2033
- tape.zero()
2034
-
2035
- idx = idx + 1
2036
-
2037
-
2038
- def test_outer_product(test, device, dtype, register_kernels=False):
2039
- rng = np.random.default_rng(123)
2040
-
2041
- tol = {
2042
- np.float16: 5.0e-3,
2043
- np.float32: 1.0e-6,
2044
- np.float64: 1.0e-8,
2045
- }.get(dtype, 0)
2046
-
2047
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2048
- vec2 = wp.types.vector(length=2, dtype=wptype)
2049
- vec3 = wp.types.vector(length=3, dtype=wptype)
2050
- vec4 = wp.types.vector(length=4, dtype=wptype)
2051
- vec5 = wp.types.vector(length=5, dtype=wptype)
2052
-
2053
- output_select_kernel = get_select_kernel(wptype)
2054
-
2055
- def check_mat_outer_product(
2056
- s2: wp.array(dtype=vec2),
2057
- s3: wp.array(dtype=vec3),
2058
- s4: wp.array(dtype=vec4),
2059
- s5: wp.array(dtype=vec5),
2060
- v2: wp.array(dtype=vec2),
2061
- v3: wp.array(dtype=vec3),
2062
- v4: wp.array(dtype=vec4),
2063
- v5: wp.array(dtype=vec5),
2064
- outcomponents: wp.array(dtype=wptype),
2065
- ):
2066
- m22result = wptype(2) * wp.outer(s2[0], v2[0])
2067
- m33result = wptype(2) * wp.outer(s3[0], v3[0])
2068
- m44result = wptype(2) * wp.outer(s4[0], v4[0])
2069
- m55result = wptype(2) * wp.outer(s5[0], v5[0])
2070
- m25result = wptype(2) * wp.outer(s2[0], v5[0])
2071
-
2072
- # multiply outputs by 2 so we've got something to backpropagate:
2073
- idx = 0
2074
- for i in range(2):
2075
- for j in range(2):
2076
- outcomponents[idx] = m22result[i, j]
2077
- idx = idx + 1
2078
-
2079
- for i in range(3):
2080
- for j in range(3):
2081
- outcomponents[idx] = m33result[i, j]
2082
- idx = idx + 1
2083
-
2084
- for i in range(4):
2085
- for j in range(4):
2086
- outcomponents[idx] = m44result[i, j]
2087
- idx = idx + 1
2088
-
2089
- for i in range(5):
2090
- for j in range(5):
2091
- outcomponents[idx] = m55result[i, j]
2092
- idx = idx + 1
2093
-
2094
- for i in range(2):
2095
- for j in range(5):
2096
- outcomponents[idx] = m25result[i, j]
2097
- idx = idx + 1
2098
-
2099
- kernel = getkernel(check_mat_outer_product, suffix=dtype.__name__)
2100
-
2101
- if register_kernels:
2102
- return
2103
-
2104
- s2 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
2105
- s3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
2106
- s4 = wp.array(randvals(rng, [1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
2107
- s5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
2108
- v2 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
2109
- v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
2110
- v4 = wp.array(randvals(rng, [1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
2111
- v5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
2112
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 5, dtype=wptype, requires_grad=True, device=device)
2113
-
2114
- wp.launch(kernel, dim=1, inputs=[s2, s3, s4, s5, v2, v3, v4, v5], outputs=[outcomponents], device=device)
2115
-
2116
- assert_np_equal(outcomponents.numpy()[:4], 2 * s2.numpy()[0, :, None] * v2.numpy()[0, None, :], tol=tol)
2117
- assert_np_equal(outcomponents.numpy()[4:13], 2 * s3.numpy()[0, :, None] * v3.numpy()[0, None, :], tol=10 * tol)
2118
- assert_np_equal(outcomponents.numpy()[13:29], 2 * s4.numpy()[0, :, None] * v4.numpy()[0, None, :], tol=10 * tol)
2119
- assert_np_equal(outcomponents.numpy()[29:54], 2 * s5.numpy()[0, :, None] * v5.numpy()[0, None, :], tol=10 * tol)
2120
- assert_np_equal(outcomponents.numpy()[54:], 2 * s2.numpy()[0, :, None] * v5.numpy()[0, None, :], tol=10 * tol)
2121
-
2122
- if dtype in np_float_types:
2123
- idx = 0
2124
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2125
- for s, v in [(s2, v2), (s3, v3), (s4, v4), (s5, v5), (s2, v5)]:
2126
- rows = s.dtype._length_
2127
- cols = v.dtype._length_
2128
- for i in range(rows):
2129
- for j in range(cols):
2130
- tape = wp.Tape()
2131
- with tape:
2132
- wp.launch(
2133
- kernel,
2134
- dim=1,
2135
- inputs=[
2136
- s2,
2137
- s3,
2138
- s4,
2139
- s5,
2140
- v2,
2141
- v3,
2142
- v4,
2143
- v5,
2144
- ],
2145
- outputs=[outcomponents],
2146
- device=device,
2147
- )
2148
- wp.launch(
2149
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2150
- )
2151
- tape.backward(loss=out)
2152
-
2153
- # this component's gonna be s_i * v_j, so its s gradient is gonna be nozero
2154
- # at the ith component and its v gradient will be nonzero at the jth component:
2155
-
2156
- expectedresult = np.zeros((rows), dtype=dtype)
2157
- expectedresult[i] = 2 * v.numpy()[0, j]
2158
- assert_np_equal(tape.gradients[s].numpy()[0], expectedresult, tol=10 * tol)
2159
-
2160
- expectedresult = np.zeros((cols), dtype=dtype)
2161
- expectedresult[j] = 2 * s.numpy()[0, i]
2162
- assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=10 * tol)
2163
- tape.zero()
2164
-
2165
- idx = idx + 1
2166
-
2167
-
2168
- def test_scalar_division(test, device, dtype, register_kernels=False):
2169
- rng = np.random.default_rng(123)
2170
-
2171
- tol = {
2172
- np.float16: 1.0e-2,
2173
- np.float32: 1.0e-6,
2174
- np.float64: 1.0e-8,
2175
- }.get(dtype, 0)
2176
-
2177
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2178
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2179
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2180
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2181
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2182
-
2183
- output_select_kernel = get_select_kernel(wptype)
2184
-
2185
- def check_mat_scalar_div(
2186
- s: wp.array(dtype=wptype),
2187
- m2: wp.array(dtype=mat22),
2188
- m3: wp.array(dtype=mat33),
2189
- m4: wp.array(dtype=mat44),
2190
- m5: wp.array(dtype=mat55),
2191
- outcomponents: wp.array(dtype=wptype),
2192
- ):
2193
- m2result = m2[0] / s[0]
2194
- m3result = m3[0] / s[0]
2195
- m4result = m4[0] / s[0]
2196
- m5result = m5[0] / s[0]
2197
-
2198
- # multiply outputs by 2 so we've got something to backpropagate:
2199
- idx = 0
2200
- for i in range(2):
2201
- for j in range(2):
2202
- outcomponents[idx] = wptype(2) * m2result[i, j]
2203
- idx = idx + 1
2204
-
2205
- for i in range(3):
2206
- for j in range(3):
2207
- outcomponents[idx] = wptype(2) * m3result[i, j]
2208
- idx = idx + 1
2209
-
2210
- for i in range(4):
2211
- for j in range(4):
2212
- outcomponents[idx] = wptype(2) * m4result[i, j]
2213
- idx = idx + 1
2214
-
2215
- for i in range(5):
2216
- for j in range(5):
2217
- outcomponents[idx] = wptype(2) * m5result[i, j]
2218
- idx = idx + 1
2219
-
2220
- kernel = getkernel(check_mat_scalar_div, suffix=dtype.__name__)
2221
-
2222
- if register_kernels:
2223
- return
2224
-
2225
- s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
2226
- m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2227
- m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2228
- m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2229
- m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2230
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2231
-
2232
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
2233
-
2234
- sval = s.numpy()[0]
2235
- if dtype in np_float_types:
2236
- assert_np_equal(outcomponents.numpy()[:4], 2 * m2.numpy().reshape(-1) / sval, tol=tol)
2237
- assert_np_equal(outcomponents.numpy()[4:13], 2 * m3.numpy().reshape(-1) / sval, tol=10 * tol)
2238
- assert_np_equal(outcomponents.numpy()[13:29], 2 * m4.numpy().reshape(-1) / sval, tol=10 * tol)
2239
- assert_np_equal(outcomponents.numpy()[29:54], 2 * m5.numpy().reshape(-1) / sval, tol=10 * tol)
2240
- else:
2241
- assert_np_equal(outcomponents.numpy()[:4], 2 * (m2.numpy().reshape(-1) // sval), tol=tol)
2242
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (m3.numpy().reshape(-1) // sval), tol=10 * tol)
2243
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (m4.numpy().reshape(-1) // sval), tol=10 * tol)
2244
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (m5.numpy().reshape(-1) // sval), tol=10 * tol)
2245
-
2246
- if dtype in np_float_types:
2247
- idx = 0
2248
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2249
- for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
2250
- for i in range(dim):
2251
- for j in range(dim):
2252
- tape = wp.Tape()
2253
- with tape:
2254
- wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
2255
- wp.launch(
2256
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2257
- )
2258
- tape.backward(loss=out)
2259
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2260
- expectedresult[i, j] = 2.0 / sval
2261
- assert_np_equal(tape.gradients[input].numpy()[0], expectedresult, tol=10 * tol)
2262
- assert_np_equal(
2263
- tape.gradients[s].numpy()[0], -2 * input.numpy()[0, i, j] / (sval * sval), tol=10 * tol
2264
- )
2265
- tape.zero()
2266
-
2267
- idx = idx + 1
2268
-
2269
-
2270
- def test_addition(test, device, dtype, register_kernels=False):
2271
- rng = np.random.default_rng(123)
2272
-
2273
- tol = {
2274
- np.float16: 2.0e-2,
2275
- np.float32: 5.0e-6,
2276
- np.float64: 1.0e-8,
2277
- }.get(dtype, 0)
2278
-
2279
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2280
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2281
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2282
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2283
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2284
-
2285
- output_select_kernel = get_select_kernel(wptype)
2286
-
2287
- def check_mat_add(
2288
- s2: wp.array(dtype=mat22),
2289
- s3: wp.array(dtype=mat33),
2290
- s4: wp.array(dtype=mat44),
2291
- s5: wp.array(dtype=mat55),
2292
- v2: wp.array(dtype=mat22),
2293
- v3: wp.array(dtype=mat33),
2294
- v4: wp.array(dtype=mat44),
2295
- v5: wp.array(dtype=mat55),
2296
- outcomponents: wp.array(dtype=wptype),
2297
- ):
2298
- v2result = v2[0] + s2[0]
2299
- v3result = v3[0] + s3[0]
2300
- v4result = v4[0] + s4[0]
2301
- v5result = v5[0] + s5[0]
2302
-
2303
- # multiply outputs by 2 so we've got something to backpropagate:
2304
- idx = 0
2305
- for i in range(2):
2306
- for j in range(2):
2307
- outcomponents[idx] = wptype(2) * v2result[i, j]
2308
- idx = idx + 1
2309
-
2310
- for i in range(3):
2311
- for j in range(3):
2312
- outcomponents[idx] = wptype(2) * v3result[i, j]
2313
- idx = idx + 1
2314
-
2315
- for i in range(4):
2316
- for j in range(4):
2317
- outcomponents[idx] = wptype(2) * v4result[i, j]
2318
- idx = idx + 1
2319
-
2320
- for i in range(5):
2321
- for j in range(5):
2322
- outcomponents[idx] = wptype(2) * v5result[i, j]
2323
- idx = idx + 1
2324
-
2325
- kernel = getkernel(check_mat_add, suffix=dtype.__name__)
2326
-
2327
- if register_kernels:
2328
- return
2329
-
2330
- s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2331
- s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2332
- s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2333
- s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2334
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2335
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2336
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2337
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2338
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2339
-
2340
- wp.launch(
2341
- kernel,
2342
- dim=1,
2343
- inputs=[
2344
- s2,
2345
- s3,
2346
- s4,
2347
- s5,
2348
- v2,
2349
- v3,
2350
- v4,
2351
- v5,
2352
- ],
2353
- outputs=[outcomponents],
2354
- device=device,
2355
- )
2356
-
2357
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() + s2.numpy()).reshape(-1), tol=tol)
2358
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() + s3.numpy()).reshape(-1), tol=tol)
2359
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() + s4.numpy()).reshape(-1), tol=tol)
2360
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() + s5.numpy()).reshape(-1), tol=tol)
2361
-
2362
- if dtype in np_float_types:
2363
- idx = 0
2364
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2365
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
2366
- for i in range(dim):
2367
- for j in range(dim):
2368
- tape = wp.Tape()
2369
- with tape:
2370
- wp.launch(
2371
- kernel,
2372
- dim=1,
2373
- inputs=[
2374
- s2,
2375
- s3,
2376
- s4,
2377
- s5,
2378
- v2,
2379
- v3,
2380
- v4,
2381
- v5,
2382
- ],
2383
- outputs=[outcomponents],
2384
- device=device,
2385
- )
2386
- wp.launch(
2387
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2388
- )
2389
- tape.backward(loss=out)
2390
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2391
- expectedresult[i, j] = 2
2392
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=10 * tol)
2393
- expectedresult[i, j] = 2
2394
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=10 * tol)
2395
- tape.zero()
2396
-
2397
- idx = idx + 1
2398
-
2399
-
2400
- def test_subtraction(test, device, dtype, register_kernels=False):
2401
- rng = np.random.default_rng(123)
2402
-
2403
- tol = {
2404
- np.float16: 5.0e-3,
2405
- np.float32: 1.0e-6,
2406
- np.float64: 1.0e-8,
2407
- }.get(dtype, 0)
2408
-
2409
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2410
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2411
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2412
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2413
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2414
-
2415
- output_select_kernel = get_select_kernel(wptype)
2416
-
2417
- def check_mat_sub(
2418
- s2: wp.array(dtype=mat22),
2419
- s3: wp.array(dtype=mat33),
2420
- s4: wp.array(dtype=mat44),
2421
- s5: wp.array(dtype=mat55),
2422
- v2: wp.array(dtype=mat22),
2423
- v3: wp.array(dtype=mat33),
2424
- v4: wp.array(dtype=mat44),
2425
- v5: wp.array(dtype=mat55),
2426
- outcomponents: wp.array(dtype=wptype),
2427
- ):
2428
- v2result = v2[0] - s2[0]
2429
- v3result = v3[0] - s3[0]
2430
- v4result = v4[0] - s4[0]
2431
- v5result = v5[0] - s5[0]
2432
-
2433
- # multiply outputs by 2 so we've got something to backpropagate:
2434
- idx = 0
2435
- for i in range(2):
2436
- for j in range(2):
2437
- outcomponents[idx] = wptype(2) * v2result[i, j]
2438
- idx = idx + 1
2439
-
2440
- for i in range(3):
2441
- for j in range(3):
2442
- outcomponents[idx] = wptype(2) * v3result[i, j]
2443
- idx = idx + 1
2444
-
2445
- for i in range(4):
2446
- for j in range(4):
2447
- outcomponents[idx] = wptype(2) * v4result[i, j]
2448
- idx = idx + 1
2449
-
2450
- for i in range(5):
2451
- for j in range(5):
2452
- outcomponents[idx] = wptype(2) * v5result[i, j]
2453
- idx = idx + 1
2454
-
2455
- kernel = getkernel(check_mat_sub, suffix=dtype.__name__)
2456
-
2457
- if register_kernels:
2458
- return
2459
-
2460
- s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2461
- s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2462
- s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2463
- s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2464
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2465
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2466
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2467
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2468
- outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2469
-
2470
- wp.launch(
2471
- kernel,
2472
- dim=1,
2473
- inputs=[
2474
- s2,
2475
- s3,
2476
- s4,
2477
- s5,
2478
- v2,
2479
- v3,
2480
- v4,
2481
- v5,
2482
- ],
2483
- outputs=[outcomponents],
2484
- device=device,
2485
- )
2486
-
2487
- assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() - s2.numpy()).reshape(-1), tol=tol)
2488
- assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() - s3.numpy()).reshape(-1), tol=tol)
2489
- assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() - s4.numpy()).reshape(-1), tol=tol)
2490
- assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() - s5.numpy()).reshape(-1), tol=10 * tol)
2491
-
2492
- if dtype in np_float_types:
2493
- idx = 0
2494
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2495
- for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
2496
- for i in range(dim):
2497
- for j in range(dim):
2498
- tape = wp.Tape()
2499
- with tape:
2500
- wp.launch(
2501
- kernel,
2502
- dim=1,
2503
- inputs=[
2504
- s2,
2505
- s3,
2506
- s4,
2507
- s5,
2508
- v2,
2509
- v3,
2510
- v4,
2511
- v5,
2512
- ],
2513
- outputs=[outcomponents],
2514
- device=device,
2515
- )
2516
- wp.launch(
2517
- output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
2518
- )
2519
- tape.backward(loss=out)
2520
- expectedresult = np.zeros((dim, dim), dtype=dtype)
2521
- expectedresult[i, j] = 2
2522
- assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=10 * tol)
2523
- expectedresult[i, j] = -2
2524
- assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=10 * tol)
2525
- tape.zero()
2526
-
2527
- idx = idx + 1
2528
-
2529
-
2530
- def test_ddot(test, device, dtype, register_kernels=False):
2531
- rng = np.random.default_rng(123)
2532
-
2533
- tol = {
2534
- np.float16: 5.0e-3,
2535
- np.float32: 1.0e-6,
2536
- np.float64: 1.0e-8,
2537
- }.get(dtype, 0)
2538
-
2539
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2540
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2541
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2542
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2543
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2544
-
2545
- def check_mat_dot(
2546
- s2: wp.array(dtype=mat22),
2547
- s3: wp.array(dtype=mat33),
2548
- s4: wp.array(dtype=mat44),
2549
- s5: wp.array(dtype=mat55),
2550
- v2: wp.array(dtype=mat22),
2551
- v3: wp.array(dtype=mat33),
2552
- v4: wp.array(dtype=mat44),
2553
- v5: wp.array(dtype=mat55),
2554
- dot2: wp.array(dtype=wptype),
2555
- dot3: wp.array(dtype=wptype),
2556
- dot4: wp.array(dtype=wptype),
2557
- dot5: wp.array(dtype=wptype),
2558
- ):
2559
- # multiply outputs by 2 so we've got something to backpropagate:
2560
- dot2[0] = wptype(2) * wp.ddot(v2[0], s2[0])
2561
- dot3[0] = wptype(2) * wp.ddot(v3[0], s3[0])
2562
- dot4[0] = wptype(2) * wp.ddot(v4[0], s4[0])
2563
- dot5[0] = wptype(2) * wp.ddot(v5[0], s5[0])
2564
-
2565
- kernel = getkernel(check_mat_dot, suffix=dtype.__name__)
2566
-
2567
- if register_kernels:
2568
- return
2569
-
2570
- s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2571
- s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2572
- s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2573
- s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2574
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2575
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2576
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2577
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2578
- dot2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2579
- dot3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2580
- dot4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2581
- dot5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2582
-
2583
- tape = wp.Tape()
2584
- with tape:
2585
- wp.launch(
2586
- kernel,
2587
- dim=1,
2588
- inputs=[
2589
- s2,
2590
- s3,
2591
- s4,
2592
- s5,
2593
- v2,
2594
- v3,
2595
- v4,
2596
- v5,
2597
- ],
2598
- outputs=[dot2, dot3, dot4, dot5],
2599
- device=device,
2600
- )
2601
-
2602
- assert_np_equal(dot2.numpy()[0], 2 * (v2.numpy() * s2.numpy()).sum(), tol=10 * tol)
2603
- assert_np_equal(dot3.numpy()[0], 2 * (v3.numpy() * s3.numpy()).sum(), tol=10 * tol)
2604
- assert_np_equal(dot4.numpy()[0], 2 * (v4.numpy() * s4.numpy()).sum(), tol=50 * tol)
2605
- assert_np_equal(dot5.numpy()[0], 2 * (v5.numpy() * s5.numpy()).sum(), tol=200 * tol)
2606
-
2607
- if dtype in np_float_types:
2608
- tape.backward(loss=dot2)
2609
- sgrads = tape.gradients[s2].numpy()[0]
2610
- expected_grads = 2.0 * v2.numpy()[0]
2611
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2612
-
2613
- vgrads = tape.gradients[v2].numpy()[0]
2614
- expected_grads = 2.0 * s2.numpy()[0]
2615
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2616
-
2617
- tape.zero()
2618
-
2619
- tape.backward(loss=dot3)
2620
- sgrads = tape.gradients[s3].numpy()[0]
2621
- expected_grads = 2.0 * v3.numpy()[0]
2622
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2623
-
2624
- vgrads = tape.gradients[v3].numpy()[0]
2625
- expected_grads = 2.0 * s3.numpy()[0]
2626
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
2627
-
2628
- tape.zero()
2629
-
2630
- tape.backward(loss=dot4)
2631
- sgrads = tape.gradients[s4].numpy()[0]
2632
- expected_grads = 2.0 * v4.numpy()[0]
2633
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
2634
-
2635
- vgrads = tape.gradients[v4].numpy()[0]
2636
- expected_grads = 2.0 * s4.numpy()[0]
2637
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
520
+ s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
521
+ s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
522
+ s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
523
+ s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
524
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
525
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
526
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
527
+ v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
528
+ outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
2638
529
 
2639
- tape.zero()
530
+ wp.launch(
531
+ kernel,
532
+ dim=1,
533
+ inputs=[
534
+ s2,
535
+ s3,
536
+ s4,
537
+ s5,
538
+ v2,
539
+ v3,
540
+ v4,
541
+ v5,
542
+ ],
543
+ outputs=[outcomponents],
544
+ device=device,
545
+ )
2640
546
 
2641
- tape.backward(loss=dot5)
2642
- sgrads = tape.gradients[s5].numpy()[0]
2643
- expected_grads = 2.0 * v5.numpy()[0]
2644
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
547
+ assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() - s2.numpy()).reshape(-1), tol=tol)
548
+ assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() - s3.numpy()).reshape(-1), tol=tol)
549
+ assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() - s4.numpy()).reshape(-1), tol=tol)
550
+ assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() - s5.numpy()).reshape(-1), tol=10 * tol)
2645
551
 
2646
- vgrads = tape.gradients[v5].numpy()[0]
2647
- expected_grads = 2.0 * s5.numpy()[0]
2648
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
552
+ if dtype in np_float_types:
553
+ idx = 0
554
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
555
+ for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
556
+ for i in range(dim):
557
+ for j in range(dim):
558
+ tape = wp.Tape()
559
+ with tape:
560
+ wp.launch(
561
+ kernel,
562
+ dim=1,
563
+ inputs=[
564
+ s2,
565
+ s3,
566
+ s4,
567
+ s5,
568
+ v2,
569
+ v3,
570
+ v4,
571
+ v5,
572
+ ],
573
+ outputs=[outcomponents],
574
+ device=device,
575
+ )
576
+ wp.launch(
577
+ output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
578
+ )
579
+ tape.backward(loss=out)
580
+ expectedresult = np.zeros((dim, dim), dtype=dtype)
581
+ expectedresult[i, j] = 2
582
+ assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=10 * tol)
583
+ expectedresult[i, j] = -2
584
+ assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=10 * tol)
585
+ tape.zero()
2649
586
 
2650
- tape.zero()
587
+ idx = idx + 1
2651
588
 
2652
589
 
2653
590
  def test_determinant(test, device, dtype, register_kernels=False):
@@ -2787,271 +724,122 @@ def test_determinant(test, device, dtype, register_kernels=False):
2787
724
  outputs=[
2788
725
  det2,
2789
726
  det3,
2790
- det4,
2791
- ],
2792
- device=device,
2793
- )
2794
- dplus = det3.numpy()[0]
2795
- v3test[0, i, j] -= 2.0 * dx
2796
- wp.launch(
2797
- kernel,
2798
- dim=1,
2799
- inputs=[
2800
- v2,
2801
- wp.array(v3test, dtype=v3.dtype, requires_grad=True, device=device),
2802
- v4,
2803
- ],
2804
- outputs=[
2805
- det2,
2806
- det3,
2807
- det4,
2808
- ],
2809
- device=device,
2810
- )
2811
- dminus = det3.numpy()[0]
2812
- assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
2813
-
2814
- for i in range(4):
2815
- for j in range(4):
2816
- v4test = v4.numpy()
2817
- v4test[0, i, j] += dx
2818
- wp.launch(
2819
- kernel,
2820
- dim=1,
2821
- inputs=[
2822
- v2,
2823
- v3,
2824
- wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
2825
- ],
2826
- outputs=[
2827
- det2,
2828
- det3,
2829
- det4,
2830
- ],
2831
- device=device,
2832
- )
2833
- dplus = det4.numpy()[0]
2834
- v4test[0, i, j] -= 2.0 * dx
2835
- wp.launch(
2836
- kernel,
2837
- dim=1,
2838
- inputs=[
2839
- v2,
2840
- v3,
2841
- wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
2842
- ],
2843
- outputs=[
2844
- det2,
2845
- det3,
2846
- det4,
2847
- ],
2848
- device=device,
2849
- )
2850
- dminus = det4.numpy()[0]
2851
- assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
2852
-
2853
-
2854
- def test_trace(test, device, dtype, register_kernels=False):
2855
- rng = np.random.default_rng(123)
2856
-
2857
- tol = {
2858
- np.float16: 1.0e-3,
2859
- np.float32: 1.0e-6,
2860
- np.float64: 1.0e-8,
2861
- }.get(dtype, 0)
2862
-
2863
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2864
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
2865
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
2866
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
2867
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
2868
-
2869
- def check_mat_trace(
2870
- v2: wp.array(dtype=mat22),
2871
- v3: wp.array(dtype=mat33),
2872
- v4: wp.array(dtype=mat44),
2873
- v5: wp.array(dtype=mat55),
2874
- tr2: wp.array(dtype=wptype),
2875
- tr3: wp.array(dtype=wptype),
2876
- tr4: wp.array(dtype=wptype),
2877
- tr5: wp.array(dtype=wptype),
2878
- ):
2879
- # multiply outputs by 2 so we've got something to backpropagate:
2880
- tr2[0] = wptype(2) * wp.trace(v2[0])
2881
- tr3[0] = wptype(2) * wp.trace(v3[0])
2882
- tr4[0] = wptype(2) * wp.trace(v4[0])
2883
- tr5[0] = wptype(2) * wp.trace(v5[0])
2884
-
2885
- kernel = getkernel(check_mat_trace, suffix=dtype.__name__)
2886
-
2887
- if register_kernels:
2888
- return
2889
-
2890
- v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
2891
- v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
2892
- v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
2893
- v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
2894
- tr2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2895
- tr3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2896
- tr4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2897
- tr5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2898
-
2899
- tape = wp.Tape()
2900
- with tape:
2901
- wp.launch(
2902
- kernel,
2903
- dim=1,
2904
- inputs=[
2905
- v2,
2906
- v3,
2907
- v4,
2908
- v5,
2909
- ],
2910
- outputs=[
2911
- tr2,
2912
- tr3,
2913
- tr4,
2914
- tr5,
2915
- ],
2916
- device=device,
2917
- )
2918
-
2919
- assert_np_equal(tr2.numpy()[0], 2 * np.trace(v2.numpy()[0]), tol=10 * tol)
2920
- assert_np_equal(tr3.numpy()[0], 2 * np.trace(v3.numpy()[0]), tol=10 * tol)
2921
- assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
2922
- assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
2923
-
2924
- if dtype in np_float_types:
2925
- tape.backward(loss=tr2)
2926
- vgrads = tape.gradients[v2].numpy()[0]
2927
- assert_np_equal(vgrads, 2.0 * np.eye(2), tol=10 * tol)
2928
- tape.zero()
2929
-
2930
- tape.backward(loss=tr3)
2931
- vgrads = tape.gradients[v3].numpy()[0]
2932
- assert_np_equal(vgrads, 2.0 * np.eye(3), tol=10 * tol)
2933
- tape.zero()
2934
-
2935
- tape.backward(loss=tr4)
2936
- vgrads = tape.gradients[v4].numpy()[0]
2937
- assert_np_equal(vgrads, 2.0 * np.eye(4), tol=10 * tol)
2938
- tape.zero()
2939
-
2940
- tape.backward(loss=tr5)
2941
- vgrads = tape.gradients[v5].numpy()[0]
2942
- assert_np_equal(vgrads, 2.0 * np.eye(5), tol=10 * tol)
2943
- tape.zero()
2944
-
2945
-
2946
- def test_diag(test, device, dtype, register_kernels=False):
2947
- rng = np.random.default_rng(123)
2948
-
2949
- tol = {
2950
- np.float16: 1.0e-3,
2951
- np.float32: 1.0e-6,
2952
- np.float64: 1.0e-8,
2953
- }.get(dtype, 0)
2954
-
2955
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2956
- vec5 = wp.types.vector(length=5, dtype=wptype)
2957
-
2958
- output_select_kernel = get_select_kernel(wptype)
2959
-
2960
- def check_mat_diag(
2961
- s5: wp.array(dtype=vec5),
2962
- outcomponents: wp.array(dtype=wptype),
2963
- ):
2964
- # multiply outputs by 2 so we've got something to backpropagate:
2965
- m55result = wptype(2) * wp.diag(s5[0])
2966
-
2967
- idx = 0
2968
- for i in range(5):
2969
- for j in range(5):
2970
- outcomponents[idx] = m55result[i, j]
2971
- idx = idx + 1
2972
-
2973
- kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
2974
-
2975
- if register_kernels:
2976
- return
2977
-
2978
- s5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
2979
- outcomponents = wp.zeros(5 * 5, dtype=wptype, requires_grad=True, device=device)
2980
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
2981
-
2982
- wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
2983
-
2984
- assert_np_equal(outcomponents.numpy(), 2 * np.diag(s5.numpy()[0]), tol=tol)
2985
-
2986
- if dtype in np_float_types:
2987
- idx = 0
2988
- for i in range(5):
2989
- for j in range(5):
2990
- tape = wp.Tape()
2991
- with tape:
2992
- wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
2993
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
2994
- tape.backward(loss=out)
2995
- expectedresult = np.zeros(5, dtype=dtype)
2996
- if i == j:
2997
- expectedresult[i] = 2
2998
- assert_np_equal(tape.gradients[s5].numpy()[0], expectedresult, tol=10 * tol)
2999
- tape.zero()
3000
-
3001
- idx = idx + 1
3002
-
3003
-
3004
- def test_get_diag(test, device, dtype, register_kernels=False):
3005
- tol = {
3006
- np.float16: 1.0e-3,
3007
- np.float32: 1.0e-6,
3008
- np.float64: 1.0e-8,
3009
- }.get(dtype, 0)
3010
-
3011
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3012
- mat55 = wp.types.vector(shape=(5, 5), dtype=wptype)
3013
-
3014
- output_select_kernel = get_select_kernel(wptype)
3015
-
3016
- def check_mat_diag(
3017
- m55: wp.array(dtype=mat55),
3018
- outcomponents: wp.array(dtype=wptype),
3019
- ):
3020
- # multiply outputs by 2 so we've got something to backpropagate:
3021
- vec5result = wptype(2) * wp.get_diag(m55[0])
3022
-
3023
- idx = 0
3024
- for i in range(5):
3025
- outcomponents[idx] = vec5result[i]
3026
- idx = idx + 1
3027
-
3028
- kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
3029
-
3030
- if register_kernels:
3031
- return
3032
-
3033
- m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
3034
- outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
3035
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3036
-
3037
- wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
727
+ det4,
728
+ ],
729
+ device=device,
730
+ )
731
+ dplus = det3.numpy()[0]
732
+ v3test[0, i, j] -= 2.0 * dx
733
+ wp.launch(
734
+ kernel,
735
+ dim=1,
736
+ inputs=[
737
+ v2,
738
+ wp.array(v3test, dtype=v3.dtype, requires_grad=True, device=device),
739
+ v4,
740
+ ],
741
+ outputs=[
742
+ det2,
743
+ det3,
744
+ det4,
745
+ ],
746
+ device=device,
747
+ )
748
+ dminus = det3.numpy()[0]
749
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
3038
750
 
3039
- assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
751
+ for i in range(4):
752
+ for j in range(4):
753
+ v4test = v4.numpy()
754
+ v4test[0, i, j] += dx
755
+ wp.launch(
756
+ kernel,
757
+ dim=1,
758
+ inputs=[
759
+ v2,
760
+ v3,
761
+ wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
762
+ ],
763
+ outputs=[
764
+ det2,
765
+ det3,
766
+ det4,
767
+ ],
768
+ device=device,
769
+ )
770
+ dplus = det4.numpy()[0]
771
+ v4test[0, i, j] -= 2.0 * dx
772
+ wp.launch(
773
+ kernel,
774
+ dim=1,
775
+ inputs=[
776
+ v2,
777
+ v3,
778
+ wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device),
779
+ ],
780
+ outputs=[
781
+ det2,
782
+ det3,
783
+ det4,
784
+ ],
785
+ device=device,
786
+ )
787
+ dminus = det4.numpy()[0]
788
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
3040
789
 
3041
- if dtype in np_float_types:
3042
- idx = 0
3043
- for i in range(5):
3044
- tape = wp.Tape()
3045
- with tape:
3046
- wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
3047
- wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
3048
- tape.backward(loss=out)
3049
- expectedresult = np.zeros((5, 5), dtype=dtype)
3050
- expectedresult[i, i] = 2
3051
- assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
3052
- tape.zero()
3053
790
 
3054
- idx = idx + 1
791
+ # Unused. Why?
792
+ # def test_get_diag(test, device, dtype, register_kernels=False):
793
+ # tol = {
794
+ # np.float16: 1.0e-3,
795
+ # np.float32: 1.0e-6,
796
+ # np.float64: 1.0e-8,
797
+ # }.get(dtype, 0)
798
+ #
799
+ # wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
800
+ # mat55 = wp.types.vector(shape=(5, 5), dtype=wptype)
801
+ #
802
+ # output_select_kernel = get_select_kernel(wptype)
803
+ #
804
+ # def check_mat_diag(
805
+ # m55: wp.array(dtype=mat55),
806
+ # outcomponents: wp.array(dtype=wptype),
807
+ # ):
808
+ # # multiply outputs by 2 so we've got something to backpropagate:
809
+ # vec5result = wptype(2) * wp.get_diag(m55[0])
810
+ #
811
+ # idx = 0
812
+ # for i in range(5):
813
+ # outcomponents[idx] = vec5result[i]
814
+ # idx = idx + 1
815
+ #
816
+ # kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
817
+ #
818
+ # if register_kernels:
819
+ # return
820
+ #
821
+ # m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
822
+ # outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
823
+ # out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
824
+ #
825
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
826
+ #
827
+ # assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
828
+ #
829
+ # if dtype in np_float_types:
830
+ # idx = 0
831
+ # for i in range(5):
832
+ # tape = wp.Tape()
833
+ # with tape:
834
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
835
+ # wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
836
+ # tape.backward(loss=out)
837
+ # expectedresult = np.zeros((5, 5), dtype=dtype)
838
+ # expectedresult[i, i] = 2
839
+ # assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
840
+ # tape.zero()
841
+ #
842
+ # idx = idx + 1
3055
843
 
3056
844
 
3057
845
  def test_inverse(test, device, dtype, register_kernels=False):
@@ -3790,330 +1578,6 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
3790
1578
  tape.zero()
3791
1579
 
3792
1580
 
3793
- def test_anon_type_instance(test, device, dtype, register_kernels=False):
3794
- rng = np.random.default_rng(123)
3795
-
3796
- tol = {
3797
- np.float16: 5.0e-3,
3798
- np.float32: 1.0e-6,
3799
- np.float64: 1.0e-8,
3800
- }.get(dtype, 0)
3801
-
3802
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3803
-
3804
- def check_scalar_init(
3805
- input: wp.array(dtype=wptype),
3806
- output: wp.array(dtype=wptype),
3807
- ):
3808
- m2result = wp.matrix(input[0], shape=(2, 2))
3809
- m3result = wp.matrix(input[1], shape=(3, 3))
3810
- m4result = wp.matrix(input[2], shape=(4, 4))
3811
- m5result = wp.matrix(input[3], shape=(5, 5))
3812
- m32result = wp.matrix(input[4], shape=(3, 2))
3813
-
3814
- idx = 0
3815
- for i in range(2):
3816
- for j in range(2):
3817
- output[idx] = wptype(2) * m2result[i, j]
3818
- idx = idx + 1
3819
- for i in range(3):
3820
- for j in range(3):
3821
- output[idx] = wptype(2) * m3result[i, j]
3822
- idx = idx + 1
3823
- for i in range(4):
3824
- for j in range(4):
3825
- output[idx] = wptype(2) * m4result[i, j]
3826
- idx = idx + 1
3827
- for i in range(5):
3828
- for j in range(5):
3829
- output[idx] = wptype(2) * m5result[i, j]
3830
- idx = idx + 1
3831
- for i in range(3):
3832
- for j in range(2):
3833
- output[idx] = wptype(2) * m32result[i, j]
3834
- idx = idx + 1
3835
-
3836
- def check_component_init(
3837
- input: wp.array(dtype=wptype),
3838
- output: wp.array(dtype=wptype),
3839
- ):
3840
- m2result = wp.matrix(input[0], input[1], input[2], input[3], shape=(2, 2))
3841
- m3result = wp.matrix(
3842
- input[4], input[5], input[6], input[7], input[8], input[9], input[10], input[11], input[12], shape=(3, 3)
3843
- )
3844
- m4result = wp.matrix(
3845
- input[13],
3846
- input[14],
3847
- input[15],
3848
- input[16],
3849
- input[17],
3850
- input[18],
3851
- input[19],
3852
- input[20],
3853
- input[21],
3854
- input[22],
3855
- input[23],
3856
- input[24],
3857
- input[25],
3858
- input[26],
3859
- input[27],
3860
- input[28],
3861
- shape=(4, 4),
3862
- )
3863
- m5result = wp.matrix(
3864
- input[29],
3865
- input[30],
3866
- input[31],
3867
- input[32],
3868
- input[33],
3869
- input[34],
3870
- input[35],
3871
- input[36],
3872
- input[37],
3873
- input[38],
3874
- input[39],
3875
- input[40],
3876
- input[41],
3877
- input[42],
3878
- input[43],
3879
- input[44],
3880
- input[45],
3881
- input[46],
3882
- input[47],
3883
- input[48],
3884
- input[49],
3885
- input[50],
3886
- input[51],
3887
- input[52],
3888
- input[53],
3889
- shape=(5, 5),
3890
- )
3891
- m32result = wp.matrix(input[54], input[55], input[56], input[57], input[58], input[59], shape=(3, 2))
3892
-
3893
- idx = 0
3894
- for i in range(2):
3895
- for j in range(2):
3896
- output[idx] = wptype(2) * m2result[i, j]
3897
- idx = idx + 1
3898
- for i in range(3):
3899
- for j in range(3):
3900
- output[idx] = wptype(2) * m3result[i, j]
3901
- idx = idx + 1
3902
- for i in range(4):
3903
- for j in range(4):
3904
- output[idx] = wptype(2) * m4result[i, j]
3905
- idx = idx + 1
3906
- for i in range(5):
3907
- for j in range(5):
3908
- output[idx] = wptype(2) * m5result[i, j]
3909
- idx = idx + 1
3910
- for i in range(3):
3911
- for j in range(2):
3912
- output[idx] = wptype(2) * m32result[i, j]
3913
- idx = idx + 1
3914
-
3915
- scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
3916
- component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
3917
- output_select_kernel = get_select_kernel(wptype)
3918
-
3919
- if register_kernels:
3920
- return
3921
-
3922
- input = wp.array(randvals(rng, [5], dtype), requires_grad=True, device=device)
3923
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2, dtype=wptype, requires_grad=True, device=device)
3924
-
3925
- wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3926
-
3927
- assert_np_equal(output.numpy()[:4], 2 * np.array([input.numpy()[0]] * 2 * 2), tol=1.0e-6)
3928
- assert_np_equal(output.numpy()[4:13], 2 * np.array([input.numpy()[1]] * 3 * 3), tol=1.0e-6)
3929
- assert_np_equal(output.numpy()[13:29], 2 * np.array([input.numpy()[2]] * 4 * 4), tol=1.0e-6)
3930
- assert_np_equal(output.numpy()[29:54], 2 * np.array([input.numpy()[3]] * 5 * 5), tol=1.0e-6)
3931
- assert_np_equal(output.numpy()[54:], 2 * np.array([input.numpy()[4]] * 3 * 2), tol=1.0e-6)
3932
-
3933
- if dtype in np_float_types:
3934
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3935
- for i in range(len(output)):
3936
- tape = wp.Tape()
3937
- with tape:
3938
- wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3939
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
3940
-
3941
- tape.backward(loss=out)
3942
- expected = np.zeros_like(input.numpy())
3943
- if i < 4:
3944
- expected[0] = 2
3945
- elif i < 13:
3946
- expected[1] = 2
3947
- elif i < 29:
3948
- expected[2] = 2
3949
- elif i < 54:
3950
- expected[3] = 2
3951
- else:
3952
- expected[4] = 2
3953
-
3954
- assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
3955
-
3956
- tape.reset()
3957
- tape.zero()
3958
-
3959
- input = wp.array(randvals(rng, [2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2], dtype), requires_grad=True, device=device)
3960
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2, dtype=wptype, requires_grad=True, device=device)
3961
-
3962
- wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3963
-
3964
- assert_np_equal(output.numpy(), 2 * input.numpy(), tol=1.0e-6)
3965
-
3966
- if dtype in np_float_types:
3967
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
3968
- for i in range(len(output)):
3969
- tape = wp.Tape()
3970
- with tape:
3971
- wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
3972
- wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
3973
-
3974
- tape.backward(loss=out)
3975
- expected = np.zeros_like(input.numpy())
3976
- expected[i] = 2
3977
-
3978
- assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
3979
-
3980
- tape.reset()
3981
- tape.zero()
3982
-
3983
-
3984
- def test_identity(test, device, dtype, register_kernels=False):
3985
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
3986
-
3987
- def check_identity_mat(
3988
- output: wp.array(dtype=wptype),
3989
- ):
3990
- m2result = wp.identity(dtype=wptype, n=2)
3991
- m3result = wp.identity(dtype=wptype, n=3)
3992
- m4result = wp.identity(dtype=wptype, n=4)
3993
- m5result = wp.identity(dtype=wptype, n=5)
3994
-
3995
- idx = 0
3996
- for i in range(2):
3997
- for j in range(2):
3998
- output[idx] = wptype(2) * m2result[i, j]
3999
- idx = idx + 1
4000
- for i in range(3):
4001
- for j in range(3):
4002
- output[idx] = wptype(2) * m3result[i, j]
4003
- idx = idx + 1
4004
- for i in range(4):
4005
- for j in range(4):
4006
- output[idx] = wptype(2) * m4result[i, j]
4007
- idx = idx + 1
4008
- for i in range(5):
4009
- for j in range(5):
4010
- output[idx] = wptype(2) * m5result[i, j]
4011
- idx = idx + 1
4012
-
4013
- id_kernel = getkernel(check_identity_mat, suffix=dtype.__name__)
4014
-
4015
- if register_kernels:
4016
- return
4017
-
4018
- output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
4019
- wp.launch(id_kernel, dim=1, inputs=[], outputs=[output], device=device)
4020
- assert_np_equal(output.numpy()[:4], 2 * np.eye(2), tol=1.0e-6)
4021
- assert_np_equal(output.numpy()[4:13], 2 * np.eye(3), tol=1.0e-6)
4022
- assert_np_equal(output.numpy()[13:29], 2 * np.eye(4), tol=1.0e-6)
4023
- assert_np_equal(output.numpy()[29:], 2 * np.eye(5), tol=1.0e-6)
4024
-
4025
-
4026
- def test_equivalent_types(test, device, dtype, register_kernels=False):
4027
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
4028
-
4029
- # matrix types
4030
- mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
4031
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
4032
- mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
4033
- mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
4034
-
4035
- # matrix types equivalent to the above
4036
- mat22_equiv = wp.types.matrix(shape=(2, 2), dtype=wptype)
4037
- mat33_equiv = wp.types.matrix(shape=(3, 3), dtype=wptype)
4038
- mat44_equiv = wp.types.matrix(shape=(4, 4), dtype=wptype)
4039
- mat55_equiv = wp.types.matrix(shape=(5, 5), dtype=wptype)
4040
-
4041
- # declare kernel with original types
4042
- def check_equivalence(
4043
- m2: mat22,
4044
- m3: mat33,
4045
- m4: mat44,
4046
- m5: mat55,
4047
- ):
4048
- wp.expect_eq(m2, mat22(wptype(42)))
4049
- wp.expect_eq(m3, mat33(wptype(43)))
4050
- wp.expect_eq(m4, mat44(wptype(44)))
4051
- wp.expect_eq(m5, mat55(wptype(45)))
4052
-
4053
- wp.expect_eq(m2, mat22_equiv(wptype(42)))
4054
- wp.expect_eq(m3, mat33_equiv(wptype(43)))
4055
- wp.expect_eq(m4, mat44_equiv(wptype(44)))
4056
- wp.expect_eq(m5, mat55_equiv(wptype(45)))
4057
-
4058
- kernel = getkernel(check_equivalence, suffix=dtype.__name__)
4059
-
4060
- if register_kernels:
4061
- return
4062
-
4063
- # call kernel with equivalent types
4064
- m2 = mat22_equiv(42)
4065
- m3 = mat33_equiv(43)
4066
- m4 = mat44_equiv(44)
4067
- m5 = mat55_equiv(45)
4068
-
4069
- wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], device=device)
4070
-
4071
-
4072
- def test_conversions(test, device, dtype, register_kernels=False):
4073
- def check_matrices_equal(
4074
- m0: wp.mat22,
4075
- m1: wp.mat22,
4076
- m2: wp.mat22,
4077
- m3: wp.mat22,
4078
- m4: wp.mat22,
4079
- m5: wp.mat22,
4080
- m6: wp.mat22,
4081
- ):
4082
- wp.expect_eq(m1, m0)
4083
- wp.expect_eq(m2, m0)
4084
- wp.expect_eq(m3, m0)
4085
- wp.expect_eq(m4, m0)
4086
- wp.expect_eq(m5, m0)
4087
- wp.expect_eq(m6, m0)
4088
-
4089
- kernel = getkernel(check_matrices_equal, suffix=dtype.__name__)
4090
-
4091
- if register_kernels:
4092
- return
4093
-
4094
- m0 = wp.mat22(1, 2, 3, 4)
4095
-
4096
- # test explicit conversions - constructing matrices from different containers
4097
- m1 = wp.mat22(((1, 2), (3, 4))) # nested tuples
4098
- m2 = wp.mat22([[1, 2], [3, 4]]) # nested lists
4099
- m3 = wp.mat22(np.array([[1, 2], [3, 4]], dtype=dtype)) # 2d array
4100
- m4 = wp.mat22((1, 2, 3, 4)) # flat tuple
4101
- m5 = wp.mat22([1, 2, 3, 4]) # flat list
4102
- m6 = wp.mat22(np.array([1, 2, 3, 4], dtype=dtype)) # 1d array
4103
-
4104
- wp.launch(kernel, dim=1, inputs=[m0, m1, m2, m3, m4, m5, m6], device=device)
4105
-
4106
- # test implicit conversions - passing different containers as matrices to wp.launch()
4107
- m1 = ((1, 2), (3, 4)) # nested tuples
4108
- m2 = [[1, 2], [3, 4]] # nested lists
4109
- m3 = np.array([[1, 2], [3, 4]], dtype=dtype) # 2d array
4110
- m4 = (1, 2, 3, 4) # flat tuple
4111
- m5 = [1, 2, 3, 4] # flat list
4112
- m6 = np.array([1, 2, 3, 4], dtype=dtype) # 1d array
4113
-
4114
- wp.launch(kernel, dim=1, inputs=[m0, m1, m2, m3, m4, m5, m6], device=device)
4115
-
4116
-
4117
1581
  # Test matrix constructors using explicit type (float16)
4118
1582
  # note that these tests are specifically not using generics / closure
4119
1583
  # args to create kernels dynamically (like the rest of this file)
@@ -4215,229 +1679,149 @@ def test_constructors_constant_shape():
4215
1679
  m[i, j] = float(i * j)
4216
1680
 
4217
1681
 
4218
- def register(parent):
4219
- devices = get_test_devices()
4220
-
4221
- class TestMat(parent):
4222
- pass
4223
-
4224
- add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
4225
- add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
4226
- add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
4227
- add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
4228
-
4229
- mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
4230
- add_kernel_test(
4231
- TestMat,
4232
- test_matrix_mutation,
4233
- dim=1,
4234
- inputs=[
4235
- mat103(
4236
- 1.0,
4237
- 2.0,
4238
- 3.0,
4239
- 2.0,
4240
- 4.0,
4241
- 6.0,
4242
- 3.0,
4243
- 6.0,
4244
- 9.0,
4245
- 4.0,
4246
- 8.0,
4247
- 12.0,
4248
- 5.0,
4249
- 10.0,
4250
- 15.0,
4251
- 6.0,
4252
- 12.0,
4253
- 18.0,
4254
- 7.0,
4255
- 14.0,
4256
- 21.0,
4257
- 8.0,
4258
- 16.0,
4259
- 24.0,
4260
- 9.0,
4261
- 18.0,
4262
- 27.0,
4263
- 10.0,
4264
- 20.0,
4265
- 30.0,
4266
- )
4267
- ],
4268
- devices=devices,
4269
- )
4270
-
4271
- for dtype in np_signed_int_types + np_float_types:
4272
- add_function_test_register_kernel(
4273
- TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
4274
- )
4275
- add_function_test_register_kernel(
4276
- TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
1682
+ devices = get_test_devices()
1683
+
1684
+
1685
+ class TestMat(unittest.TestCase):
1686
+ pass
1687
+
1688
+
1689
+ add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
1690
+ add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
1691
+ add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
1692
+ add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
1693
+
1694
+ mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
1695
+ add_kernel_test(
1696
+ TestMat,
1697
+ test_matrix_mutation,
1698
+ dim=1,
1699
+ inputs=[
1700
+ mat103(
1701
+ 1.0,
1702
+ 2.0,
1703
+ 3.0,
1704
+ 2.0,
1705
+ 4.0,
1706
+ 6.0,
1707
+ 3.0,
1708
+ 6.0,
1709
+ 9.0,
1710
+ 4.0,
1711
+ 8.0,
1712
+ 12.0,
1713
+ 5.0,
1714
+ 10.0,
1715
+ 15.0,
1716
+ 6.0,
1717
+ 12.0,
1718
+ 18.0,
1719
+ 7.0,
1720
+ 14.0,
1721
+ 21.0,
1722
+ 8.0,
1723
+ 16.0,
1724
+ 24.0,
1725
+ 9.0,
1726
+ 18.0,
1727
+ 27.0,
1728
+ 10.0,
1729
+ 20.0,
1730
+ 30.0,
4277
1731
  )
1732
+ ],
1733
+ devices=devices,
1734
+ )
4278
1735
 
4279
- add_function_test(
4280
- TestMat,
4281
- "test_anon_constructor_error_shape_keyword_missing",
4282
- test_anon_constructor_error_shape_keyword_missing,
4283
- devices=devices,
4284
- )
4285
- add_function_test(
4286
- TestMat,
4287
- "test_anon_constructor_error_dtype_keyword_missing",
4288
- test_anon_constructor_error_dtype_keyword_missing,
4289
- devices=devices,
1736
+ for dtype in np_signed_int_types + np_float_types:
1737
+ add_function_test_register_kernel(
1738
+ TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
4290
1739
  )
4291
- add_function_test(
4292
- TestMat,
4293
- "test_anon_constructor_error_shape_mismatch",
4294
- test_anon_constructor_error_shape_mismatch,
4295
- devices=devices,
1740
+ add_function_test_register_kernel(
1741
+ TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
4296
1742
  )
1743
+
1744
+ add_function_test(
1745
+ TestMat,
1746
+ "test_anon_constructor_error_shape_keyword_missing",
1747
+ test_anon_constructor_error_shape_keyword_missing,
1748
+ devices=devices,
1749
+ )
1750
+ add_function_test(
1751
+ TestMat,
1752
+ "test_anon_constructor_error_dtype_keyword_missing",
1753
+ test_anon_constructor_error_dtype_keyword_missing,
1754
+ devices=devices,
1755
+ )
1756
+ add_function_test(
1757
+ TestMat,
1758
+ "test_anon_constructor_error_shape_mismatch",
1759
+ test_anon_constructor_error_shape_mismatch,
1760
+ devices=devices,
1761
+ )
1762
+ add_function_test(
1763
+ TestMat,
1764
+ "test_anon_constructor_error_invalid_arg_count",
1765
+ test_anon_constructor_error_invalid_arg_count,
1766
+ devices=devices,
1767
+ )
1768
+ add_function_test(
1769
+ TestMat,
1770
+ "test_tpl_constructor_error_incompatible_sizes",
1771
+ test_tpl_constructor_error_incompatible_sizes,
1772
+ devices=devices,
1773
+ )
1774
+ add_function_test(
1775
+ TestMat,
1776
+ "test_tpl_constructor_error_invalid_scalar_type",
1777
+ test_tpl_constructor_error_invalid_scalar_type,
1778
+ devices=devices,
1779
+ )
1780
+ add_function_test(
1781
+ TestMat,
1782
+ "test_tpl_constructor_error_invalid_vector_count",
1783
+ test_tpl_constructor_error_invalid_vector_count,
1784
+ devices=devices,
1785
+ )
1786
+ add_function_test(
1787
+ TestMat,
1788
+ "test_tpl_constructor_error_invalid_vector_shape",
1789
+ test_tpl_constructor_error_invalid_vector_shape,
1790
+ devices=devices,
1791
+ )
1792
+ add_function_test(
1793
+ TestMat,
1794
+ "test_tpl_constructor_error_invalid_arg_count",
1795
+ test_tpl_constructor_error_invalid_arg_count,
1796
+ devices=devices,
1797
+ )
1798
+ add_function_test(TestMat, "test_tpl_ops_with_anon", test_tpl_ops_with_anon)
1799
+
1800
+ for dtype in np_float_types:
4297
1801
  add_function_test(
4298
- TestMat,
4299
- "test_anon_constructor_error_invalid_arg_count",
4300
- test_anon_constructor_error_invalid_arg_count,
4301
- devices=devices,
1802
+ TestMat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
4302
1803
  )
4303
- add_function_test(
4304
- TestMat,
4305
- "test_tpl_constructor_error_incompatible_sizes",
4306
- test_tpl_constructor_error_incompatible_sizes,
4307
- devices=devices,
1804
+ add_function_test_register_kernel(
1805
+ TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
4308
1806
  )
4309
- add_function_test(
4310
- TestMat,
4311
- "test_tpl_constructor_error_invalid_scalar_type",
4312
- test_tpl_constructor_error_invalid_scalar_type,
4313
- devices=devices,
1807
+ add_function_test_register_kernel(
1808
+ TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
4314
1809
  )
4315
- add_function_test(
4316
- TestMat,
4317
- "test_tpl_constructor_error_invalid_vector_count",
4318
- test_tpl_constructor_error_invalid_vector_count,
4319
- devices=devices,
1810
+ add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
1811
+ add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
1812
+ add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
1813
+ add_function_test_register_kernel(
1814
+ TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
4320
1815
  )
4321
- add_function_test(
4322
- TestMat,
4323
- "test_tpl_constructor_error_invalid_vector_shape",
4324
- test_tpl_constructor_error_invalid_vector_shape,
4325
- devices=devices,
1816
+ add_function_test_register_kernel(
1817
+ TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
4326
1818
  )
4327
- add_function_test(
4328
- TestMat,
4329
- "test_tpl_constructor_error_invalid_arg_count",
4330
- test_tpl_constructor_error_invalid_arg_count,
4331
- devices=devices,
1819
+ add_function_test_register_kernel(
1820
+ TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
4332
1821
  )
4333
- add_function_test(TestMat, "test_tpl_ops_with_anon", test_tpl_ops_with_anon)
4334
-
4335
- for dtype in np_scalar_types:
4336
- add_function_test(TestMat, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
4337
- add_function_test(TestMat, f"test_components_{dtype.__name__}", test_components, devices=None, dtype=dtype)
4338
- add_function_test_register_kernel(
4339
- TestMat, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
4340
- )
4341
- add_function_test_register_kernel(
4342
- TestMat, f"test_anon_type_instance_{dtype.__name__}", test_anon_type_instance, devices=devices, dtype=dtype
4343
- )
4344
- add_function_test_register_kernel(
4345
- TestMat, f"test_identity_{dtype.__name__}", test_identity, devices=devices, dtype=dtype
4346
- )
4347
- add_function_test_register_kernel(
4348
- TestMat, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
4349
- )
4350
- add_function_test_register_kernel(
4351
- TestMat, f"test_equality_{dtype.__name__}", test_equality, devices=devices, dtype=dtype
4352
- )
4353
- add_function_test_register_kernel(
4354
- TestMat,
4355
- f"test_scalar_multiplication_{dtype.__name__}",
4356
- test_scalar_multiplication,
4357
- devices=devices,
4358
- dtype=dtype,
4359
- )
4360
- add_function_test_register_kernel(
4361
- TestMat,
4362
- f"test_matvec_multiplication_{dtype.__name__}",
4363
- test_matvec_multiplication,
4364
- devices=devices,
4365
- dtype=dtype,
4366
- )
4367
- add_function_test_register_kernel(
4368
- TestMat,
4369
- f"test_matmat_multiplication_{dtype.__name__}",
4370
- test_matmat_multiplication,
4371
- devices=devices,
4372
- dtype=dtype,
4373
- )
4374
- add_function_test_register_kernel(
4375
- TestMat, f"test_cw_multiplication_{dtype.__name__}", test_cw_multiplication, devices=devices, dtype=dtype
4376
- )
4377
- add_function_test_register_kernel(
4378
- TestMat, f"test_cw_division_{dtype.__name__}", test_cw_division, devices=devices, dtype=dtype
4379
- )
4380
- add_function_test_register_kernel(
4381
- TestMat, f"test_outer_product_{dtype.__name__}", test_outer_product, devices=devices, dtype=dtype
4382
- )
4383
- add_function_test_register_kernel(
4384
- TestMat, f"test_transpose_{dtype.__name__}", test_transpose, devices=devices, dtype=dtype
4385
- )
4386
- add_function_test_register_kernel(
4387
- TestMat, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
4388
- )
4389
- add_function_test_register_kernel(
4390
- TestMat, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
4391
- )
4392
- add_function_test_register_kernel(
4393
- TestMat, f"test_ddot_{dtype.__name__}", test_ddot, devices=devices, dtype=dtype
4394
- )
4395
- add_function_test_register_kernel(
4396
- TestMat, f"test_trace_{dtype.__name__}", test_trace, devices=devices, dtype=dtype
4397
- )
4398
- add_function_test_register_kernel(
4399
- TestMat, f"test_diag_{dtype.__name__}", test_diag, devices=devices, dtype=dtype
4400
- )
4401
- add_function_test_register_kernel(
4402
- TestMat, f"test_get_diag_{dtype.__name__}", test_diag, devices=devices, dtype=dtype
4403
- )
4404
- add_function_test_register_kernel(
4405
- TestMat, f"test_equivalent_types_{dtype.__name__}", test_equivalent_types, devices=devices, dtype=dtype
4406
- )
4407
- add_function_test_register_kernel(
4408
- TestMat, f"test_conversions_{dtype.__name__}", test_conversions, devices=devices, dtype=dtype
4409
- )
4410
- add_function_test_register_kernel(
4411
- TestMat, f"test_constants_{dtype.__name__}", test_constants, devices=devices, dtype=dtype
4412
- )
4413
-
4414
- for dtype in np_float_types:
4415
- add_function_test_register_kernel(
4416
- TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
4417
- )
4418
- add_function_test_register_kernel(
4419
- TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
4420
- )
4421
- add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
4422
- add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
4423
- add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
4424
- add_function_test_register_kernel(
4425
- TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
4426
- )
4427
- add_function_test_register_kernel(
4428
- TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
4429
- )
4430
- add_function_test_register_kernel(
4431
- TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
4432
- )
4433
- add_function_test_register_kernel(
4434
- TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype
4435
- )
4436
-
4437
- return TestMat
1822
+ add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
4438
1823
 
4439
1824
 
4440
1825
  if __name__ == "__main__":
4441
1826
  wp.build.clear_kernel_cache()
4442
- _ = register(unittest.TestCase)
4443
1827
  unittest.main(verbosity=2, failfast=True)