warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187) hide show
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2099 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
+ wp.init()
16
+
17
# NumPy scalar types the tests are parameterized over, grouped by kind.
# The `byte`/`ubyte` aliases are listed explicitly alongside the sized names.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]

np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

# All integer types: signed first, then unsigned.
np_int_types = [*np_signed_int_types, *np_unsigned_int_types]

np_float_types = [np.float16, np.float32, np.float64]

# Every scalar type: integers first, then floats.
np_scalar_types = [*np_int_types, *np_float_types]
38
+
39
+
40
def randvals(rng, shape, dtype):
    """Return random test data of the given shape and NumPy dtype.

    Float dtypes get standard-normal samples cast to ``dtype``. Integer
    dtypes get values drawn with ``rng.integers`` starting at 1: the 8-bit
    types use ``high=3`` and all other integer types use ``high=5``.

    Args:
        rng: A NumPy random ``Generator`` (it must provide
            ``standard_normal`` and ``integers``).
        shape: Shape of the array to generate.
        dtype: NumPy scalar type of the result.
    """
    if dtype not in np_float_types:
        # NOTE(review): the narrower range for 8-bit types presumably keeps
        # derived values representable - the source does not say; confirm.
        upper = 3 if dtype in (np.int8, np.uint8, np.byte, np.ubyte) else 5
        return rng.integers(1, high=upper, size=shape, dtype=dtype)

    samples = rng.standard_normal(size=shape)
    return samples.astype(dtype)
46
+
47
+
48
# Kernels created by getkernel(), keyed by "<function name>_<suffix>".
kernel_cache = {}
49
+
50
+
51
def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` for ``func``, creating it on first request.

    Kernels are stored in ``kernel_cache`` under the key
    ``"<func.__name__>_<suffix>"``, so the same function can be registered
    once per suffix (callers in this file pass a dtype name).

    Args:
        func: Python function to wrap as a kernel.
        suffix: String appended to the function name to form the cache key.
    """
    cache_key = "_".join((func.__name__, suffix))
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    kernel_cache[cache_key] = wp.Kernel(func=func, key=cache_key)
    return kernel_cache[cache_key]
56
+
57
+
58
def get_select_kernel(dtype):
    """Return a cached kernel that copies one element of a 1-D array to ``out[0]``.

    The kernel is obtained through ``getkernel`` with ``dtype.__name__`` as
    the suffix, so one kernel is kept per element type.

    Args:
        dtype: Element type of both the input and output arrays; it must
            expose ``__name__``.
    """

    # Kernel body: out[0] = input[index].
    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index]

    return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
67
+
68
+
69
def get_select_kernel2(dtype):
    """Return a cached kernel that copies one element of a 2-D array to ``out[0]``.

    The kernel is obtained through ``getkernel`` with ``dtype.__name__`` as
    the suffix, so one kernel is kept per element type.

    Args:
        dtype: Element type of both the 2-D input array and the 1-D output
            array; it must expose ``__name__``.
    """

    # Kernel body: out[0] = input[index0, index1].
    def output_select_kernel2_fn(
        input: wp.array(dtype=dtype, ndim=2),
        index0: int,
        index1: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index0, index1]

    return getkernel(output_select_kernel2_fn, suffix=dtype.__name__)
79
+
80
+
81
def test_arrays(test, device, dtype):
    """Round-trip random data through Warp vector arrays of length 2 to 5.

    For each vector length, ten random rows are generated with ``randvals``,
    wrapped in a ``wp.array`` of the matching vector type on ``device``, and
    read back with ``.numpy()``; the result must match the source data.

    The original body repeated the length 2-4 checks a second time with
    identical types and data; that copy exercised nothing new and has been
    folded into this single loop.

    Args:
        test: The running ``unittest.TestCase`` (unused here, kept for the
            common test-function signature).
        device: Warp device on which the arrays are created.
        dtype: NumPy scalar type of the vector components.
    """
    rng = np.random.default_rng(123)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Lengths are visited in ascending order so the random draws occur in the
    # same sequence as before (2, 3, 4, 5).
    for length in (2, 3, 4, 5):
        vec_type = wp.types.vector(length=length, dtype=wptype)
        values_np = randvals(rng, (10, length), dtype)
        values = wp.array(values_np, dtype=vec_type, requires_grad=True, device=device)

        assert_np_equal(values.numpy(), values_np, tol=1.0e-6)
116
+
117
+
118
def test_components(test, device, dtype):
    """Exercise Python-side component access on a 3-component Warp vector.

    Covers ``__getitem__`` and ``__setitem__`` with both integer indices and
    slices. This is especially important for float16, which requires special
    handling internally.

    Args:
        test: The running ``unittest.TestCase``; used for the assertions.
        device: Unused; kept for the common test-function signature.
        dtype: NumPy scalar type of the vector components.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def expect_items(seq, values):
        # Compare item by item through integer indexing, front to back.
        for idx, value in enumerate(values):
            test.assertEqual(seq[idx], value)

    v = vec3(1, 2, 3)

    # __getitem__ with integer indices
    expect_items(v, (1, 2, 3))

    # __getitem__ with slices
    expect_items(v[:], (1, 2, 3))
    expect_items(v[1:], (2, 3))
    expect_items(v[:2], (1, 2))
    expect_items(v[::2], (1, 3))

    # __setitem__ with integer indices
    v[0] = 4
    v[1] = 5
    v[2] = 6
    expect_items(v, (4, 5, 6))

    # __setitem__ with slices; each step also checks the untouched components
    v[:] = [7, 8, 9]
    expect_items(v, (7, 8, 9))

    v[1:] = [10, 11]
    expect_items(v, (7, 10, 11))

    v[:2] = [12, 13]
    expect_items(v, (12, 13, 11))

    v[::2] = [14, 15]
    expect_items(v, (14, 13, 15))
178
+
179
+
180
def test_py_arithmetic_ops(test, device, dtype):
    """Check Python-level arithmetic operators on a 3-component Warp vector.

    Covers unary ``+``/``-``, vector ``+``/``-``, and scalar ``*``/``/`` with
    the scalar on either side.

    Args:
        test: The running ``unittest.TestCase``; used for the assertions.
        device: Unused; kept for the common test-function signature.
        dtype: NumPy scalar type of the vector components.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def make_vec(*args):
        # Non-integer types: the expected components are the literals as given.
        if wptype not in wp.types.int_types:
            return args

        # Integer types: cast each literal to the scalar's underlying
        # ``_type_`` so the expectation wraps the same way (e.g. negative
        # literals for unsigned types).
        return tuple(wptype._type_(component).value for component in args)

    vec3_type = wp.vec(3, wptype)

    # Unary and vector-vector operators.
    v = vec3_type(1, -2, 3)
    test.assertSequenceEqual(+v, make_vec(1, -2, 3))
    test.assertSequenceEqual(-v, make_vec(-1, 2, -3))
    test.assertSequenceEqual(v + vec3_type(5, 5, 5), make_vec(6, 3, 8))
    test.assertSequenceEqual(v - vec3_type(5, 5, 5), make_vec(-4, -7, -2))

    # Scalar operators, with the scalar on the right and on the left.
    v = vec3_type(2, 4, 6)
    test.assertSequenceEqual(v * wptype(2), make_vec(4, 8, 12))
    test.assertSequenceEqual(wptype(2) * v, make_vec(4, 8, 12))
    test.assertSequenceEqual(v / wptype(2), make_vec(1, 2, 3))
    test.assertSequenceEqual(wptype(24) / v, make_vec(12, 6, 4))
203
+
204
+
205
def test_constructors(test, device, dtype, register_kernels=False):
    """Check vector constructors inside kernels, for lengths 2 to 5.

    Two kernels are launched:

    * ``check_scalar_constructor`` builds each vector from a single scalar
      (``vecN(x)``), so every component must equal ``x``.
    * ``check_vector_constructors`` builds each vector from one scalar per
      component, consuming 14 consecutive input values.

    Both kernels also write every component multiplied by 2 to its own
    one-element array, which gives the float dtypes something to
    backpropagate through; the gradient with respect to the source value
    must then be 2.

    All value comparisons use the per-dtype tolerance ``tol``. The original
    scalar-constructor checks hard-coded ``1.0e-6`` instead, which was
    inconsistent with the vector-constructor checks.

    Args:
        test: The running ``unittest.TestCase``; used for the assertions.
        device: Warp device on which arrays are created and kernels launched.
        dtype: NumPy scalar type of the vector components.
        register_kernels: When true, only create and cache the kernels, then
            return without launching anything.
    """
    rng = np.random.default_rng(123)

    # Float types get a precision-dependent tolerance; every other type
    # falls back to 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_scalar_constructor(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = vec2(input[0])
        v3result = vec3(input[0])
        v4result = vec4(input[0])
        v5result = vec5(input[0])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply outputs by 2 so we've got something to backpropagate
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    def check_vector_constructors(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = vec2(input[0], input[1])
        v3result = vec3(input[2], input[3], input[4])
        v4result = vec4(input[5], input[6], input[7], input[8])
        v5result = vec5(input[9], input[10], input[11], input[12], input[13])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply the output by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    vec_kernel = getkernel(check_vector_constructors, suffix=dtype.__name__)
    kernel = getkernel(check_scalar_constructor, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)

    input = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)

    # One vector output per length, in the order the kernels declare them.
    vec_outputs = [
        wp.zeros(1, dtype=vec2, device=device),
        wp.zeros(1, dtype=vec3, device=device),
        wp.zeros(1, dtype=vec4, device=device),
        wp.zeros(1, dtype=vec5, device=device),
    ]

    # One scalar output per vector component (v20, v21, v30, ..., v54), in
    # kernel-parameter order.
    component_outputs = [
        wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))
    ]

    # --- vecN(scalar): every component is a copy of input[0] ---------------
    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[input],
            outputs=[*vec_outputs, *component_outputs],
            device=device,
        )

    if dtype in np_float_types:
        for loss in component_outputs:
            tape.backward(loss=loss)
            test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
            tape.zero()

    val = input.numpy()[0]
    for vec_out, length in zip(vec_outputs, lengths):
        assert_np_equal(vec_out.numpy()[0], np.array([val] * length), tol=tol)

    for component_out in component_outputs:
        assert_np_equal(component_out.numpy()[0], 2 * val, tol=tol)

    # --- vecN(c0, c1, ...): component k of the flattened outputs is input[k]
    input = wp.array(randvals(rng, [14], dtype), requires_grad=True, device=device)
    tape = wp.Tape()
    with tape:
        wp.launch(
            vec_kernel,
            dim=1,
            inputs=[input],
            outputs=[*vec_outputs, *component_outputs],
            device=device,
        )

    if dtype in np_float_types:
        for i, loss in enumerate(component_outputs):
            tape.backward(loss=loss)
            grad = tape.gradients[input].numpy()
            expected_grad = np.zeros_like(grad)
            expected_grad[i] = 2
            assert_np_equal(grad, expected_grad, tol=tol)
            tape.zero()

    input_np = input.numpy()

    # The vectors laid end to end must reproduce the 14 input values.
    packed = np.concatenate([vec_out.numpy()[0] for vec_out in vec_outputs])
    assert_np_equal(packed, input_np, tol=tol)

    for i, component_out in enumerate(component_outputs):
        assert_np_equal(component_out.numpy()[0], 2 * input_np[i], tol=tol)
447
+
448
+
449
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Check ``wp.vector(...)`` construction without a predeclared vector type.

    Two kernels are launched, each writing 2 + 3 + 4 + 5 = 14 doubled
    components to ``output``:

    * ``check_scalar_init`` builds vectors of length 2 to 5 with
      ``wp.vector(x, length=N)``, one input value per vector.
    * ``check_component_init`` builds them with one input value per
      component.

    For float dtypes, each output entry is then differentiated in turn and
    the gradient with respect to the input must be 2 at the entry's source
    index and 0 elsewhere.

    All value comparisons use the per-dtype tolerance ``tol``. The original
    value checks hard-coded ``1.0e-6`` while the gradient checks used
    ``tol``.

    Args:
        test: The running ``unittest.TestCase`` (unused here, kept for the
            common test-function signature).
        device: Warp device on which arrays are created and kernels launched.
        dtype: NumPy scalar type of the vector components.
        register_kernels: When true, only create and cache the kernels, then
            return without launching anything.
    """
    rng = np.random.default_rng(123)

    # Float types get a precision-dependent tolerance; every other type
    # falls back to 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], length=2)
        v3result = wp.vector(input[1], length=3)
        v4result = wp.vector(input[2], length=4)
        v5result = wp.vector(input[3], length=5)

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], input[1])
        v3result = wp.vector(input[2], input[3], input[4])
        v4result = wp.vector(input[5], input[6], input[7], input[8])
        v5result = wp.vector(input[9], input[10], input[11], input[12], input[13])

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    num_outputs = sum(lengths)

    def check_gradients(kernel_under_test, source, result, source_index):
        # Differentiate each entry of `result` separately. Entry i must have
        # gradient 2 with respect to source[source_index[i]] and 0 elsewhere.
        # Only float types are differentiated.
        if dtype not in np_float_types:
            return

        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(len(result)):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel_under_test, dim=1, inputs=[source], outputs=[result], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[result, i], outputs=[out], device=device)

            tape.backward(loss=out)
            expected = np.zeros_like(source.numpy())
            expected[source_index[i]] = 2

            assert_np_equal(tape.gradients[source].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # --- wp.vector(x, length=N): vector k is filled with input[k] ----------
    input = wp.array(randvals(rng, [4], dtype), requires_grad=True, device=device)
    output = wp.zeros(num_outputs, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)

    # Expected layout: input[0] twice, input[1] three times, input[2] four
    # times, input[3] five times - each doubled.
    assert_np_equal(output.numpy(), 2 * np.repeat(input.numpy(), lengths), tol=tol)

    # Output entry i comes from the input of the vector that contains it.
    scalar_source_index = [k for k, length in enumerate(lengths) for _ in range(length)]
    check_gradients(scalar_kernel, input, output, scalar_source_index)

    # --- wp.vector(c0, c1, ...): output entry i comes from input[i] --------
    input = wp.array(randvals(rng, [num_outputs], dtype), requires_grad=True, device=device)
    output = wp.zeros(num_outputs, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)

    assert_np_equal(output.numpy(), 2 * input.numpy(), tol=tol)

    check_gradients(component_kernel, input, output, list(range(num_outputs)))
570
+
571
+
572
+ def test_indexing(test, device, dtype, register_kernels=False):
573
+ rng = np.random.default_rng(123)
574
+
575
+ tol = {
576
+ np.float16: 5.0e-3,
577
+ np.float32: 1.0e-6,
578
+ np.float64: 1.0e-8,
579
+ }.get(dtype, 0)
580
+
581
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
582
+ vec2 = wp.types.vector(length=2, dtype=wptype)
583
+ vec3 = wp.types.vector(length=3, dtype=wptype)
584
+ vec4 = wp.types.vector(length=4, dtype=wptype)
585
+ vec5 = wp.types.vector(length=5, dtype=wptype)
586
+
587
+ def check_indexing(
588
+ v2: wp.array(dtype=vec2),
589
+ v3: wp.array(dtype=vec3),
590
+ v4: wp.array(dtype=vec4),
591
+ v5: wp.array(dtype=vec5),
592
+ v20: wp.array(dtype=wptype),
593
+ v21: wp.array(dtype=wptype),
594
+ v30: wp.array(dtype=wptype),
595
+ v31: wp.array(dtype=wptype),
596
+ v32: wp.array(dtype=wptype),
597
+ v40: wp.array(dtype=wptype),
598
+ v41: wp.array(dtype=wptype),
599
+ v42: wp.array(dtype=wptype),
600
+ v43: wp.array(dtype=wptype),
601
+ v50: wp.array(dtype=wptype),
602
+ v51: wp.array(dtype=wptype),
603
+ v52: wp.array(dtype=wptype),
604
+ v53: wp.array(dtype=wptype),
605
+ v54: wp.array(dtype=wptype),
606
+ ):
607
+ # multiply outputs by 2 so we've got something to backpropagate:
608
+ v20[0] = wptype(2) * v2[0][0]
609
+ v21[0] = wptype(2) * v2[0][1]
610
+
611
+ v30[0] = wptype(2) * v3[0][0]
612
+ v31[0] = wptype(2) * v3[0][1]
613
+ v32[0] = wptype(2) * v3[0][2]
614
+
615
+ v40[0] = wptype(2) * v4[0][0]
616
+ v41[0] = wptype(2) * v4[0][1]
617
+ v42[0] = wptype(2) * v4[0][2]
618
+ v43[0] = wptype(2) * v4[0][3]
619
+
620
+ v50[0] = wptype(2) * v5[0][0]
621
+ v51[0] = wptype(2) * v5[0][1]
622
+ v52[0] = wptype(2) * v5[0][2]
623
+ v53[0] = wptype(2) * v5[0][3]
624
+ v54[0] = wptype(2) * v5[0][4]
625
+
626
+ kernel = getkernel(check_indexing, suffix=dtype.__name__)
627
+
628
+ if register_kernels:
629
+ return
630
+
631
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
632
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
633
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
634
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
635
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
636
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
637
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
638
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
639
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
640
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
641
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
642
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
643
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
644
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
645
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
646
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
647
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
648
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
649
+
650
+ tape = wp.Tape()
651
+ with tape:
652
+ wp.launch(
653
+ kernel,
654
+ dim=1,
655
+ inputs=[v2, v3, v4, v5],
656
+ outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
657
+ device=device,
658
+ )
659
+
660
+ if dtype in np_float_types:
661
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
662
+ tape.backward(loss=l)
663
+ allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
664
+ expected_grads = np.zeros_like(allgrads)
665
+ expected_grads[i] = 2
666
+ assert_np_equal(allgrads, expected_grads, tol=tol)
667
+ tape.zero()
668
+
669
+ assert_np_equal(v20.numpy()[0], 2.0 * v2.numpy()[0, 0], tol=tol)
670
+ assert_np_equal(v21.numpy()[0], 2.0 * v2.numpy()[0, 1], tol=tol)
671
+ assert_np_equal(v30.numpy()[0], 2.0 * v3.numpy()[0, 0], tol=tol)
672
+ assert_np_equal(v31.numpy()[0], 2.0 * v3.numpy()[0, 1], tol=tol)
673
+ assert_np_equal(v32.numpy()[0], 2.0 * v3.numpy()[0, 2], tol=tol)
674
+ assert_np_equal(v40.numpy()[0], 2.0 * v4.numpy()[0, 0], tol=tol)
675
+ assert_np_equal(v41.numpy()[0], 2.0 * v4.numpy()[0, 1], tol=tol)
676
+ assert_np_equal(v42.numpy()[0], 2.0 * v4.numpy()[0, 2], tol=tol)
677
+ assert_np_equal(v43.numpy()[0], 2.0 * v4.numpy()[0, 3], tol=tol)
678
+ assert_np_equal(v50.numpy()[0], 2.0 * v5.numpy()[0, 0], tol=tol)
679
+ assert_np_equal(v51.numpy()[0], 2.0 * v5.numpy()[0, 1], tol=tol)
680
+ assert_np_equal(v52.numpy()[0], 2.0 * v5.numpy()[0, 2], tol=tol)
681
+ assert_np_equal(v53.numpy()[0], 2.0 * v5.numpy()[0, 3], tol=tol)
682
+ assert_np_equal(v54.numpy()[0], 2.0 * v5.numpy()[0, 4], tol=tol)
683
+
684
+
685
def test_equality(test, device, dtype, register_kernels=False):
    """Exercise vector equality/inequality for lengths 2 to 5 inside a kernel.

    For each length the kernel expects a reference vector to compare equal to
    itself and unequal to variants that differ from it in a single component.
    With ``register_kernels`` set, only the kernel is created and the function
    returns before anything is allocated or launched.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_equality(
        v20: wp.array(dtype=vec2),
        v21: wp.array(dtype=vec2),
        v22: wp.array(dtype=vec2),
        v30: wp.array(dtype=vec3),
        v31: wp.array(dtype=vec3),
        v32: wp.array(dtype=vec3),
        v33: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v41: wp.array(dtype=vec4),
        v42: wp.array(dtype=vec4),
        v43: wp.array(dtype=vec4),
        v44: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
        v51: wp.array(dtype=vec5),
        v52: wp.array(dtype=vec5),
        v53: wp.array(dtype=vec5),
        v54: wp.array(dtype=vec5),
        v55: wp.array(dtype=vec5),
    ):
        # the first array of each length is the reference; the others are
        # compared against it
        wp.expect_eq(v20[0], v20[0])
        wp.expect_neq(v21[0], v20[0])
        wp.expect_neq(v22[0], v20[0])

        wp.expect_eq(v30[0], v30[0])
        wp.expect_neq(v31[0], v30[0])
        wp.expect_neq(v32[0], v30[0])
        wp.expect_neq(v33[0], v30[0])

        wp.expect_eq(v40[0], v40[0])
        wp.expect_neq(v41[0], v40[0])
        wp.expect_neq(v42[0], v40[0])
        wp.expect_neq(v43[0], v40[0])
        wp.expect_neq(v44[0], v40[0])

        wp.expect_eq(v50[0], v50[0])
        wp.expect_neq(v51[0], v50[0])
        wp.expect_neq(v52[0], v50[0])
        wp.expect_neq(v53[0], v50[0])
        wp.expect_neq(v54[0], v50[0])
        wp.expect_neq(v55[0], v50[0])

    kernel = getkernel(check_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    def as_array(components, vectype):
        # every kernel argument is built from a flat list of components
        return wp.array(components, dtype=vectype, requires_grad=True, device=device)

    # Component lists per vector length; the first entry of each group is the
    # reference, each following entry changes one component of it.
    length2 = [
        [1.0, 2.0],
        [1.0, 3.0],
        [3.0, 2.0],
    ]
    length3 = [
        [1.0, 2.0, 3.0],
        [-1.0, 2.0, 3.0],
        [1.0, -2.0, 3.0],
        [1.0, 2.0, -3.0],
    ]
    length4 = [
        [1.0, 2.0, 3.0, 4.0],
        [-1.0, 2.0, 3.0, 4.0],
        [1.0, -2.0, 3.0, 4.0],
        [1.0, 2.0, -3.0, 4.0],
        [1.0, 2.0, 3.0, -4.0],
    ]
    length5 = [
        [1.0, 2.0, 3.0, 4.0, 5.0],
        [-1.0, 2.0, 3.0, 4.0, 5.0],
        [1.0, -2.0, 3.0, 4.0, 5.0],
        [1.0, 2.0, -3.0, 4.0, 5.0],
        [1.0, 2.0, 3.0, -4.0, 5.0],
        [1.0, 2.0, 3.0, 4.0, -5.0],
    ]

    # arrays are created in kernel-argument order: vec2 group first, vec5 last
    kernel_inputs = []
    kernel_inputs += [as_array(components, vec2) for components in length2]
    kernel_inputs += [as_array(components, vec3) for components in length3]
    kernel_inputs += [as_array(components, vec4) for components in length4]
    kernel_inputs += [as_array(components, vec5) for components in length5]

    wp.launch(
        kernel,
        dim=1,
        inputs=kernel_inputs,
        outputs=[],
        device=device,
    )
786
+
787
+
788
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check ``scalar * vector`` for vector lengths 2 to 5, values and gradients.

    The kernel multiplies each input vector by the scalar ``s[0]`` from the
    left and writes twice every resulting component to its own one-element
    output array. The forward values are compared against NumPy for every
    dtype; for floating-point dtypes each output is additionally
    back-propagated and the gradients with respect to the scalar and to all
    four input vectors are checked.

    With ``register_kernels`` set, only the kernel is created and the function
    returns before anything is allocated or launched.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; other dtypes must match exactly
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_mul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = s[0] * v2[0]
        v3result = s[0] * v3[0]
        v4result = s[0] * v4[0]
        v5result = s[0] * v5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    # random draws happen in the order s, v2, v3, v4, v5
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vec_inputs = [v2, v3, v4, v5]

    # one scalar output per vector component: 2 + 3 + 4 + 5 = 14, laid out in
    # the same order as the concatenated input components
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s] + vec_inputs,
            outputs=outputs,
            device=device,
        )

    scale = s.numpy()[0]
    incmps = np.concatenate([v.numpy()[0] for v in vec_inputs])

    for i, out in enumerate(outputs):
        # the two vec2 outputs keep the tight tolerance, all others use 10x
        out_tol = tol if i < 2 else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * scale * incmps[i], tol=out_tol)

    if dtype in np_float_types:
        # Back-propagate every output, including the vec5 ones, so that the
        # adjoint of the length-5 multiplication is verified as well.
        for i, l in enumerate(outputs):
            tape.backward(loss=l)

            # d/ds (2 * s * v_i) = 2 * v_i
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)

            # d/dv_j (2 * s * v_i) = 2 * s for j == i and 0 otherwise
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vec_inputs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s.numpy()[0] * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)

            # clear accumulated gradients before the next loss
            tape.zero()
918
+
919
+
920
def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=False):
    """Check ``vector * scalar`` for vector lengths 2 to 5, values and gradients.

    The kernel multiplies each input vector by the scalar ``s[0]`` from the
    right and writes twice every resulting component to its own one-element
    output array. The forward values are compared against NumPy for every
    dtype; for floating-point dtypes each output is additionally
    back-propagated and the gradients with respect to the scalar and to all
    four input vectors are checked.

    With ``register_kernels`` set, only the kernel is created and the function
    returns before anything is allocated or launched.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; other dtypes must match exactly
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_rightmul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] * s[0]
        v3result = v3[0] * s[0]
        v4result = v4[0] * s[0]
        v5result = v5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_rightmul, suffix=dtype.__name__)

    if register_kernels:
        return

    # random draws happen in the order s, v2, v3, v4, v5
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vec_inputs = [v2, v3, v4, v5]

    # one scalar output per vector component: 2 + 3 + 4 + 5 = 14, laid out in
    # the same order as the concatenated input components
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s] + vec_inputs,
            outputs=outputs,
            device=device,
        )

    scale = s.numpy()[0]
    incmps = np.concatenate([v.numpy()[0] for v in vec_inputs])

    for i, out in enumerate(outputs):
        # the two vec2 outputs keep the tight tolerance, all others use 10x
        out_tol = tol if i < 2 else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * scale * incmps[i], tol=out_tol)

    if dtype in np_float_types:
        # Back-propagate every output, including the vec5 ones, so that the
        # adjoint of the length-5 multiplication is verified as well.
        for i, l in enumerate(outputs):
            tape.backward(loss=l)

            # d/ds (2 * v_i * s) = 2 * v_i
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)

            # d/dv_j (2 * v_i * s) = 2 * s for j == i and 0 otherwise
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vec_inputs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s.numpy()[0] * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)

            # clear accumulated gradients before the next loss
            tape.zero()
1050
+
1051
+
1052
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_mul`` for vector lengths 2 to 5, values and gradients.

    The kernel multiplies each ``s*`` vector component-wise with the matching
    ``v*`` vector and writes twice every resulting component to its own
    one-element output array. Forward values are compared against NumPy for
    every dtype; for floating-point dtypes each output is back-propagated and
    the gradients with respect to both operands are checked.

    With ``register_kernels`` set, only the kernel is created and the function
    returns before anything is allocated or launched.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; other dtypes must match exactly
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_mul(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_mul(s2[0], v2[0])
        v3result = wp.cw_mul(s3[0], v3[0])
        v4result = wp.cw_mul(s4[0], v4[0])
        v5result = wp.cw_mul(s5[0], v5[0])

        # each output carries a factor of 2 so the backward pass has a
        # non-trivial derivative to propagate
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    # random draws happen in the order s2..s5 followed by v2..v5
    s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    left_operands = [s2, s3, s4, s5]
    right_operands = [v2, v3, v4, v5]

    # one scalar output per vector component: 2 + 3 + 4 + 5 = 14
    results = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=left_operands + right_operands,
            outputs=results,
            device=device,
        )

    # flattened components, in the same order as the outputs
    incmps = np.concatenate([v.numpy()[0] for v in right_operands])
    scmps = np.concatenate([v.numpy()[0] for v in left_operands])

    for idx, result in enumerate(results):
        assert_np_equal(result.numpy()[0], 2 * scmps[idx] * incmps[idx], tol=10 * tol)

    if dtype in np_float_types:
        for idx, loss in enumerate(results):
            tape.backward(loss=loss)

            # gradient w.r.t. the left operand is twice the right component
            sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in left_operands])
            expected_sgrads = np.zeros_like(sgrads)
            expected_sgrads[idx] = incmps[idx] * 2
            assert_np_equal(sgrads, expected_sgrads, tol=10 * tol)

            # gradient w.r.t. the right operand is twice the left component
            vgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in right_operands])
            expected_vgrads = np.zeros_like(vgrads)
            expected_vgrads[idx] = scmps[idx] * 2
            assert_np_equal(vgrads, expected_vgrads, tol=10 * tol)

            # clear accumulated gradients before the next loss
            tape.zero()
1195
+
1196
+
1197
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check ``vector / scalar`` for vector lengths 2 to 5, values and gradients.

    The kernel divides each input vector by the scalar ``s[0]`` and writes
    twice every resulting component to its own one-element output array.
    Forward values are compared against NumPy (floor division for integer
    dtypes, true division otherwise); for floating-point dtypes each output is
    back-propagated and the gradients with respect to the scalar and the
    vectors are checked.

    With ``register_kernels`` set, only the kernel is created and the function
    returns before anything is allocated or launched.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; other dtypes must match exactly
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_div(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] / s[0]
        v3result = v3[0] / s[0]
        v4result = v4[0] / s[0]
        v5result = v5[0] / s[0]

        # each output carries a factor of 2 so the backward pass has a
        # non-trivial derivative to propagate
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_div, suffix=dtype.__name__)

    if register_kernels:
        return

    # random draws happen in the order s, v2, v3, v4, v5
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    numerators = [v2, v3, v4, v5]

    # one scalar output per vector component: 2 + 3 + 4 + 5 = 14
    results = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s] + numerators,
            outputs=results,
            device=device,
        )

    divisor = s.numpy()[0]
    # flattened input components, in the same order as the outputs
    incmps = np.concatenate([v.numpy()[0] for v in numerators])

    integer_dtype = dtype in np_int_types
    for idx, result in enumerate(results):
        if integer_dtype:
            expected = 2 * (incmps[idx] // divisor)
        else:
            expected = 2 * incmps[idx] / divisor
        # the two vec2 outputs use the tight tolerance, all others use 10x
        check_tol = tol if idx < 2 else 10 * tol
        assert_np_equal(result.numpy()[0], expected, tol=check_tol)

    if dtype in np_float_types:
        for idx, loss in enumerate(results):
            tape.backward(loss=loss)

            # d/ds v/s = -v/s^2
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, -2 * incmps[idx] / (divisor * divisor), tol=10 * tol)

            # d/dv v/s = 1/s
            vgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in numerators])
            expected_vgrads = np.zeros_like(vgrads)
            expected_vgrads[idx] = 2 / divisor
            assert_np_equal(vgrads, expected_vgrads, tol=tol)

            # clear accumulated gradients before the next loss
            tape.zero()
1351
+
1352
+
1353
def test_cw_division(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_div`` on vectors of length 2 to 5.

    The kernel doubles every component of the component-wise quotient and
    writes it to its own one-element array, so each component can be used as a
    scalar loss. Forward values are compared against NumPy (floor division for
    integer dtypes, true division otherwise); for float dtypes the gradients
    with respect to both operands are checked as well.

    When ``register_kernels`` is true only the kernel is created and the
    function returns before launching anything.
    """
    rng = np.random.default_rng(123)

    # dtypes missing from this table (the non-float ones) get a tolerance of 0
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_div(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # component-wise quotients v / s for each vector length
        quot2 = wp.cw_div(v2[0], s2[0])
        quot3 = wp.cw_div(v3[0], s3[0])
        quot4 = wp.cw_div(v4[0], s4[0])
        quot5 = wp.cw_div(v5[0], s5[0])

        # each component is doubled and stored in its own output array
        v20[0] = wptype(2) * quot2[0]
        v21[0] = wptype(2) * quot2[1]

        v30[0] = wptype(2) * quot3[0]
        v31[0] = wptype(2) * quot3[1]
        v32[0] = wptype(2) * quot3[2]

        v40[0] = wptype(2) * quot4[0]
        v41[0] = wptype(2) * quot4[1]
        v42[0] = wptype(2) * quot4[2]
        v43[0] = wptype(2) * quot4[3]

        v50[0] = wptype(2) * quot5[0]
        v51[0] = wptype(2) * quot5[1]
        v52[0] = wptype(2) * quot5[2]
        v53[0] = wptype(2) * quot5[3]
        v54[0] = wptype(2) * quot5[4]

    kernel = getkernel(check_cw_div, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vec_array(length, vec_type):
        return wp.array(randvals(rng, (1, length), dtype), dtype=vec_type, requires_grad=True, device=device)

    # The random draws happen in this exact order: all divisors, then all dividends.
    divisors = [
        random_vec_array(2, vec2),
        random_vec_array(3, vec3),
        random_vec_array(4, vec4),
        random_vec_array(5, vec5),
    ]
    dividends = [
        random_vec_array(2, vec2),
        random_vec_array(3, vec3),
        random_vec_array(4, vec4),
        random_vec_array(5, vec5),
    ]

    # one scalar output per vector component: 2 + 3 + 4 + 5 = 14
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=divisors + dividends,
            outputs=outputs,
            device=device,
        )

    # flattened components, in the same order as ``outputs``
    numerators = np.concatenate([arr.numpy()[0] for arr in dividends])
    denominators = np.concatenate([arr.numpy()[0] for arr in divisors])

    integer_dtype = dtype in np_int_types
    for out, num, den in zip(outputs, numerators, denominators):
        if integer_dtype:
            expected = 2 * (num // den)
        else:
            expected = 2 * num / den
        assert_np_equal(out.numpy()[0], expected, tol=tol)

    if dtype in np_float_types:
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)

            # d/ds (2 v / s) = -2 v / s^2, and only component i contributes
            divisor_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in divisors])
            expected_grads = np.zeros_like(divisor_grads)
            expected_grads[i] = -numerators[i] * 2 / (denominators[i] * denominators[i])
            assert_np_equal(divisor_grads, expected_grads, tol=20 * tol)

            # d/dv (2 v / s) = 2 / s
            dividend_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in dividends])
            expected_grads = np.zeros_like(dividend_grads)
            expected_grads[i] = 2 / denominators[i]
            assert_np_equal(dividend_grads, expected_grads, tol=tol)

            tape.zero()
1519
+
1520
+
1521
def test_addition(test, device, dtype, register_kernels=False):
    """Check vector addition on vectors of length 2 to 5.

    The kernel doubles every component of ``v + s`` and writes it to its own
    one-element array. Forward values are compared against NumPy; for float
    dtypes the gradients with respect to both operands are checked as well.

    When ``register_kernels`` is true only the kernel is created and the
    function returns before launching anything.
    """
    rng = np.random.default_rng(123)

    # dtypes missing from this table (the non-float ones) get a tolerance of 0
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_add(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # vector sums v + s for each vector length
        sum2 = v2[0] + s2[0]
        sum3 = v3[0] + s3[0]
        sum4 = v4[0] + s4[0]
        sum5 = v5[0] + s5[0]

        # each component is doubled and stored in its own output array
        v20[0] = wptype(2) * sum2[0]
        v21[0] = wptype(2) * sum2[1]

        v30[0] = wptype(2) * sum3[0]
        v31[0] = wptype(2) * sum3[1]
        v32[0] = wptype(2) * sum3[2]

        v40[0] = wptype(2) * sum4[0]
        v41[0] = wptype(2) * sum4[1]
        v42[0] = wptype(2) * sum4[2]
        v43[0] = wptype(2) * sum4[3]

        v50[0] = wptype(2) * sum5[0]
        v51[0] = wptype(2) * sum5[1]
        v52[0] = wptype(2) * sum5[2]
        v53[0] = wptype(2) * sum5[3]
        v54[0] = wptype(2) * sum5[4]

    kernel = getkernel(check_add, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vec_array(length, vec_type):
        return wp.array(randvals(rng, (1, length), dtype), dtype=vec_type, requires_grad=True, device=device)

    # The random draws happen in this exact order: all ``s`` operands, then all ``v`` operands.
    s_arrays = [
        random_vec_array(2, vec2),
        random_vec_array(3, vec3),
        random_vec_array(4, vec4),
        random_vec_array(5, vec5),
    ]
    v_arrays = [
        random_vec_array(2, vec2),
        random_vec_array(3, vec3),
        random_vec_array(4, vec4),
        random_vec_array(5, vec5),
    ]

    # one scalar output per vector component: 2 + 3 + 4 + 5 = 14
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=s_arrays + v_arrays,
            outputs=outputs,
            device=device,
        )

    # flattened components, in the same order as ``outputs``
    v_components = np.concatenate([arr.numpy()[0] for arr in v_arrays])
    s_components = np.concatenate([arr.numpy()[0] for arr in s_arrays])

    # the final component of the length-5 vector is compared with twice the tolerance
    forward_tols = [tol] * (len(outputs) - 1) + [2 * tol]

    for out, v_val, s_val, out_tol in zip(outputs, v_components, s_components, forward_tols):
        assert_np_equal(out.numpy()[0], 2 * (v_val + s_val), tol=out_tol)

    if dtype in np_float_types:
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)

            # d/ds 2 (v + s) = d/dv 2 (v + s) = 2 for component i, zero elsewhere
            s_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in s_arrays])
            expected_grads = np.zeros_like(s_grads)
            expected_grads[i] = 2
            assert_np_equal(s_grads, expected_grads, tol=10 * tol)

            v_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in v_arrays])
            assert_np_equal(v_grads, expected_grads, tol=tol)

            tape.zero()
1660
+
1661
+
1662
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Check ``wp.dot`` on vectors of length 2 to 5.

    The kernel writes twice the dot product of each ``(v, s)`` pair to a
    one-element array. Forward values are compared against NumPy; for float
    dtypes the gradients with respect to both operands are checked as well.

    When ``register_kernels`` is true only the kernel is created and the
    function returns before launching anything.
    """
    rng = np.random.default_rng(123)

    # dtypes missing from this table (the non-float ones) get a tolerance of 0
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_dot(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        # doubled dot products, one output array per vector length
        dot2[0] = wptype(2) * wp.dot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.dot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.dot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.dot(v5[0], s5[0])

    kernel = getkernel(check_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vec_array(length, vec_type):
        return wp.array(randvals(rng, (1, length), dtype), dtype=vec_type, requires_grad=True, device=device)

    # The random draws happen in this exact order: all ``s`` operands, then all ``v`` operands.
    s_arrays = [
        random_vec_array(2, vec2),
        random_vec_array(3, vec3),
        random_vec_array(4, vec4),
        random_vec_array(5, vec5),
    ]
    v_arrays = [
        random_vec_array(2, vec2),
        random_vec_array(3, vec3),
        random_vec_array(4, vec4),
        random_vec_array(5, vec5),
    ]
    dots = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=s_arrays + v_arrays,
            outputs=dots,
            device=device,
        )

    for s_arr, v_arr, dot_arr in zip(s_arrays, v_arrays, dots):
        assert_np_equal(dot_arr.numpy()[0], 2.0 * (v_arr.numpy() * s_arr.numpy()).sum(), tol=10 * tol)

    if dtype in np_float_types:
        # the gradient w.r.t. ``v`` uses a looser tolerance for the length-5 case only
        v_grad_tols = [tol, tol, tol, 10 * tol]

        for s_arr, v_arr, dot_arr, v_grad_tol in zip(s_arrays, v_arrays, dots, v_grad_tols):
            tape.backward(loss=dot_arr)

            # d/ds 2 (v . s) = 2 v
            s_grads = tape.gradients[s_arr].numpy()[0]
            assert_np_equal(s_grads, 2.0 * v_arr.numpy()[0], tol=10 * tol)

            # d/dv 2 (v . s) = 2 s
            v_grads = tape.gradients[v_arr].numpy()[0]
            assert_np_equal(v_grads, 2.0 * s_arr.numpy()[0], tol=v_grad_tol)

            tape.zero()
1781
+
1782
+
1783
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Launch a kernel declared with one set of vector types using equivalent ones.

    Two sets of vector types with identical length and dtype are created
    through separate ``wp.types.vector`` calls. The kernel is annotated with
    the first set and called with values built from the second set; inside the
    kernel the arguments are compared against literals of both sets.

    When ``register_kernels`` is true only the kernel is created and the
    function returns before launching anything.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # vector types used in the kernel signature
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    # separately requested vector types with the same length and dtype
    vec2_equiv, vec3_equiv, vec4_equiv, vec5_equiv = (
        wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5)
    )

    # kernel declared with the original types
    def check_equivalence(
        v2: vec2,
        v3: vec3,
        v4: vec4,
        v5: vec5,
    ):
        # compare against literals of the declared types ...
        wp.expect_eq(v2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

        # ... and against literals of the equivalent types
        wp.expect_eq(v2, vec2_equiv(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3_equiv(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4_equiv(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5_equiv(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # arguments are built from the equivalent types, not the declared ones
    launch_args = [
        vec2_equiv(1, 2),
        vec3_equiv(1, 2, 3),
        vec4_equiv(1, 2, 3, 4),
        vec5_equiv(1, 2, 3, 4, 5),
    ]

    wp.launch(kernel, dim=1, inputs=launch_args, device=device)
1827
+
1828
+
1829
def test_conversions(test, device, dtype, register_kernels=False):
    """Check that tuples, lists and NumPy arrays can stand in for ``wp.vec3``.

    The containers are first converted explicitly through the ``wp.vec3``
    constructor, then passed unconverted to ``wp.launch``. In both cases the
    kernel expects every argument to equal ``wp.vec3(1, 2, 3)``.

    When ``register_kernels`` is true only the kernel is created and the
    function returns before launching anything.
    """

    def check_vectors_equal(
        v0: wp.vec3,
        v1: wp.vec3,
        v2: wp.vec3,
        v3: wp.vec3,
    ):
        # v0 is the reference value; the others must match it
        wp.expect_eq(v1, v0)
        wp.expect_eq(v2, v0)
        wp.expect_eq(v3, v0)

    kernel = getkernel(check_vectors_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.vec3(1, 2, 3)

    # explicit conversions: build vectors from different containers
    from_tuple = wp.vec3((1, 2, 3))
    from_list = wp.vec3([1, 2, 3])
    from_ndarray = wp.vec3(np.array([1, 2, 3], dtype=dtype))
    wp.launch(kernel, dim=1, inputs=[reference, from_tuple, from_list, from_ndarray], device=device)

    # implicit conversions: hand the raw containers to wp.launch() as vector arguments
    raw_tuple = (1, 2, 3)
    raw_list = [1, 2, 3]
    raw_ndarray = np.array([1, 2, 3], dtype=dtype)
    wp.launch(kernel, dim=1, inputs=[reference, raw_tuple, raw_list, raw_ndarray], device=device)
1860
+
1861
+
1862
def test_constants(test, device, dtype, register_kernels=False):
    """Check that vector constants created with ``wp.constant`` are usable in a kernel.

    Constants of length 2 to 5 are created for the Warp type matching
    ``dtype`` and compared inside the kernel against vectors constructed from
    the same component values.

    When ``register_kernels`` is true only the kernel is created and the
    function returns before launching anything.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    cv2 = wp.constant(vec2(1, 2))
    cv3 = wp.constant(vec3(1, 2, 3))
    cv4 = wp.constant(vec4(1, 2, 3, 4))
    cv5 = wp.constant(vec5(1, 2, 3, 4, 5))

    def check_vector_constants():
        wp.expect_eq(cv2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(cv3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(cv4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(cv5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_vector_constants, suffix=dtype.__name__)

    if register_kernels:
        return

    # Launch on the device under test. Without an explicit device the launch
    # would not be tied to the ``device`` argument this test is registered for.
    wp.launch(kernel, dim=1, inputs=[], device=device)
1886
+
1887
+
1888
def test_minmax(test, device, dtype, register_kernels=False):
    """Check component-wise ``wp.min`` / ``wp.max`` on vectors of length 2 to 5.

    Two 10x14 scalar arrays are read row by row; in each row the 14 columns
    are split into vectors of length 2, 3, 4 and 5. The doubled minima and
    maxima are written back into 10x14 arrays and compared against NumPy. For
    float dtypes the gradient of every single output entry with respect to
    both inputs is checked as well.

    When ``register_kernels`` is true only the kernels are created and the
    function returns before launching anything.
    """
    rng = np.random.default_rng(123)

    # TODO: the numbers are off for 16 bit floats on the cpu (but not on cuda).
    # The original author suspected the float16 arithmetic implementation and
    # hoped this would go away once that is done correctly.
    tol = {
        np.float16: 1.0e-2,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # TODO: for reasons not yet understood this kernel compiles incredibly slowly
    def check_vec_min_max(
        a: wp.array(dtype=wptype, ndim=2),
        b: wp.array(dtype=wptype, ndim=2),
        mins: wp.array(dtype=wptype, ndim=2),
        maxs: wp.array(dtype=wptype, ndim=2),
    ):
        for i in range(10):
            # results are multiplied by 2 so there is something to backpropagate
            a2 = vec2(a[i, 0], a[i, 1])
            b2 = vec2(b[i, 0], b[i, 1])
            lo2 = wptype(2) * wp.min(a2, b2)
            hi2 = wptype(2) * wp.max(a2, b2)

            a3 = vec3(a[i, 2], a[i, 3], a[i, 4])
            b3 = vec3(b[i, 2], b[i, 3], b[i, 4])
            lo3 = wptype(2) * wp.min(a3, b3)
            hi3 = wptype(2) * wp.max(a3, b3)

            a4 = vec4(a[i, 5], a[i, 6], a[i, 7], a[i, 8])
            b4 = vec4(b[i, 5], b[i, 6], b[i, 7], b[i, 8])
            lo4 = wptype(2) * wp.min(a4, b4)
            hi4 = wptype(2) * wp.max(a4, b4)

            a5 = vec5(a[i, 9], a[i, 10], a[i, 11], a[i, 12], a[i, 13])
            b5 = vec5(b[i, 9], b[i, 10], b[i, 11], b[i, 12], b[i, 13])
            lo5 = wptype(2) * wp.min(a5, b5)
            hi5 = wptype(2) * wp.max(a5, b5)

            # minima: columns 0-1, 2-4, 5-8 and 9-13
            mins[i, 0] = lo2[0]
            mins[i, 1] = lo2[1]

            mins[i, 2] = lo3[0]
            mins[i, 3] = lo3[1]
            mins[i, 4] = lo3[2]

            mins[i, 5] = lo4[0]
            mins[i, 6] = lo4[1]
            mins[i, 7] = lo4[2]
            mins[i, 8] = lo4[3]

            mins[i, 9] = lo5[0]
            mins[i, 10] = lo5[1]
            mins[i, 11] = lo5[2]
            mins[i, 12] = lo5[3]
            mins[i, 13] = lo5[4]

            # maxima: same column layout
            maxs[i, 0] = hi2[0]
            maxs[i, 1] = hi2[1]

            maxs[i, 2] = hi3[0]
            maxs[i, 3] = hi3[1]
            maxs[i, 4] = hi3[2]

            maxs[i, 5] = hi4[0]
            maxs[i, 6] = hi4[1]
            maxs[i, 7] = hi4[2]
            maxs[i, 8] = hi4[3]

            maxs[i, 9] = hi5[0]
            maxs[i, 10] = hi5[1]
            maxs[i, 11] = hi5[2]
            maxs[i, 12] = hi5[3]
            maxs[i, 13] = hi5[4]

    kernel = getkernel(check_vec_min_max, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel2(wptype)

    if register_kernels:
        return

    a = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)
    b = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)

    mins = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)
    maxs = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)

    assert_np_equal(mins.numpy(), 2 * np.minimum(a.numpy(), b.numpy()), tol=tol)
    assert_np_equal(maxs.numpy(), 2 * np.maximum(a.numpy(), b.numpy()), tol=tol)

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    def check_entry_gradient(result, row, col, a_selected, b_selected):
        # Re-run the kernel on a fresh tape, pick result[row, col] as the loss
        # and verify that only the selected input entry receives a gradient of 2.
        entry_tape = wp.Tape()
        with entry_tape:
            wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[result, row, col], outputs=[out], device=device)

        entry_tape.backward(loss=out)

        expected = np.zeros_like(a.numpy())
        expected[row, col] = 2 if a_selected else 0
        assert_np_equal(entry_tape.gradients[a].numpy(), expected, tol=tol)

        expected[row, col] = 2 if b_selected else 0
        assert_np_equal(entry_tape.gradients[b].numpy(), expected, tol=tol)

        entry_tape.zero()

    if dtype in np_float_types:
        for i in range(10):
            for j in range(14):
                a_ij = a.numpy()[i, j]
                b_ij = b.numpy()[i, j]

                # minimum: the gradient flows to whichever input is strictly smaller
                check_entry_gradient(mins, i, j, a_ij < b_ij, b_ij < a_ij)

                # maximum: the gradient flows to whichever input is strictly larger
                check_entry_gradient(maxs, i, j, a_ij > b_ij, b_ij > a_ij)
2019
+
2020
+
2021
devices = get_test_devices()


class TestVecScalarOps(unittest.TestCase):
    """Test case that the per-dtype test functions below are attached to."""


# Functions registered through add_function_test, paired with the ``devices``
# argument each one is registered with.
_function_tests = (
    (test_arrays, devices),
    (test_components, None),
    (test_py_arithmetic_ops, None),
)

# Functions registered through add_function_test_register_kernel, in registration order.
_kernel_function_tests = (
    test_constructors,
    test_anon_type_instance,
    test_indexing,
    test_equality,
    test_scalar_multiplication,
    test_scalar_multiplication_rightmul,
    test_cw_multiplication,
    test_scalar_division,
    test_cw_division,
    test_addition,
    test_dotproduct,
    test_equivalent_types,
    test_conversions,
    test_constants,
)

for dtype in np_scalar_types:
    # every registered name is "<function name>_<dtype name>"
    for func, func_devices in _function_tests:
        add_function_test(
            TestVecScalarOps, f"{func.__name__}_{dtype.__name__}", func, devices=func_devices, dtype=dtype
        )

    for func in _kernel_function_tests:
        add_function_test_register_kernel(
            TestVecScalarOps, f"{func.__name__}_{dtype.__name__}", func, devices=devices, dtype=dtype
        )

    # the kernels in this test compile incredibly slowly...
    # add_function_test_register_kernel(TestVecScalarOps, f"test_minmax_{dtype.__name__}", test_minmax, devices=devices, dtype=dtype)


if __name__ == "__main__":
    # start from an empty kernel cache when the file is run directly
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)