warp-lang 1.0.1__py3-none-manylinux2014_x86_64.whl → 1.1.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (346)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +115 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3425 -3354
  8. warp/codegen.py +2878 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +45 -45
  11. warp/context.py +5194 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +383 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -279
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
  34. warp/examples/benchmarks/benchmark_launches.py +295 -295
  35. warp/examples/browse.py +29 -28
  36. warp/examples/core/example_dem.py +234 -221
  37. warp/examples/core/example_fluid.py +293 -267
  38. warp/examples/core/example_graph_capture.py +144 -129
  39. warp/examples/core/example_marching_cubes.py +188 -176
  40. warp/examples/core/example_mesh.py +174 -154
  41. warp/examples/core/example_mesh_intersect.py +205 -193
  42. warp/examples/core/example_nvdb.py +176 -169
  43. warp/examples/core/example_raycast.py +105 -89
  44. warp/examples/core/example_raymarch.py +199 -178
  45. warp/examples/core/example_render_opengl.py +185 -141
  46. warp/examples/core/example_sph.py +405 -389
  47. warp/examples/core/example_torch.py +222 -181
  48. warp/examples/core/example_wave.py +263 -249
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +407 -391
  51. warp/examples/fem/example_convection_diffusion.py +182 -168
  52. warp/examples/fem/example_convection_diffusion_dg.py +219 -209
  53. warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
  54. warp/examples/fem/example_deformed_geometry.py +177 -159
  55. warp/examples/fem/example_diffusion.py +201 -173
  56. warp/examples/fem/example_diffusion_3d.py +177 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +221 -214
  58. warp/examples/fem/example_mixed_elasticity.py +244 -222
  59. warp/examples/fem/example_navier_stokes.py +259 -243
  60. warp/examples/fem/example_stokes.py +220 -192
  61. warp/examples/fem/example_stokes_transfer.py +265 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +260 -248
  65. warp/examples/optim/example_cloth_throw.py +222 -210
  66. warp/examples/optim/example_diffray.py +566 -535
  67. warp/examples/optim/example_drone.py +864 -835
  68. warp/examples/optim/example_inverse_kinematics.py +176 -169
  69. warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
  70. warp/examples/optim/example_spring_cage.py +239 -234
  71. warp/examples/optim/example_trajectory.py +223 -201
  72. warp/examples/optim/example_walker.py +306 -292
  73. warp/examples/sim/example_cartpole.py +139 -128
  74. warp/examples/sim/example_cloth.py +196 -184
  75. warp/examples/sim/example_granular.py +124 -113
  76. warp/examples/sim/example_granular_collision_sdf.py +197 -185
  77. warp/examples/sim/example_jacobian_ik.py +236 -213
  78. warp/examples/sim/example_particle_chain.py +118 -106
  79. warp/examples/sim/example_quadruped.py +193 -179
  80. warp/examples/sim/example_rigid_chain.py +197 -189
  81. warp/examples/sim/example_rigid_contact.py +189 -176
  82. warp/examples/sim/example_rigid_force.py +127 -126
  83. warp/examples/sim/example_rigid_gyroscopic.py +109 -97
  84. warp/examples/sim/example_rigid_soft_contact.py +134 -124
  85. warp/examples/sim/example_soft_body.py +190 -178
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +60 -27
  88. warp/fem/cache.py +401 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +15 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +744 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +441 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/partition.py +374 -376
  106. warp/fem/geometry/quadmesh_2d.py +532 -532
  107. warp/fem/geometry/tetmesh.py +840 -840
  108. warp/fem/geometry/trimesh_2d.py +577 -577
  109. warp/fem/integrate.py +1630 -1615
  110. warp/fem/operator.py +190 -191
  111. warp/fem/polynomial.py +214 -213
  112. warp/fem/quadrature/__init__.py +2 -2
  113. warp/fem/quadrature/pic_quadrature.py +243 -245
  114. warp/fem/quadrature/quadrature.py +295 -294
  115. warp/fem/space/__init__.py +294 -292
  116. warp/fem/space/basis_space.py +488 -489
  117. warp/fem/space/collocated_function_space.py +100 -105
  118. warp/fem/space/dof_mapper.py +236 -236
  119. warp/fem/space/function_space.py +148 -145
  120. warp/fem/space/grid_2d_function_space.py +267 -267
  121. warp/fem/space/grid_3d_function_space.py +305 -306
  122. warp/fem/space/hexmesh_function_space.py +350 -352
  123. warp/fem/space/partition.py +350 -350
  124. warp/fem/space/quadmesh_2d_function_space.py +368 -369
  125. warp/fem/space/restriction.py +158 -160
  126. warp/fem/space/shape/__init__.py +13 -15
  127. warp/fem/space/shape/cube_shape_function.py +738 -738
  128. warp/fem/space/shape/shape_function.py +102 -103
  129. warp/fem/space/shape/square_shape_function.py +611 -611
  130. warp/fem/space/shape/tet_shape_function.py +565 -567
  131. warp/fem/space/shape/triangle_shape_function.py +429 -429
  132. warp/fem/space/tetmesh_function_space.py +294 -292
  133. warp/fem/space/topology.py +297 -295
  134. warp/fem/space/trimesh_2d_function_space.py +223 -221
  135. warp/fem/types.py +77 -77
  136. warp/fem/utils.py +495 -495
  137. warp/jax.py +166 -141
  138. warp/jax_experimental.py +341 -339
  139. warp/native/array.h +1072 -1025
  140. warp/native/builtin.h +1560 -1560
  141. warp/native/bvh.cpp +398 -398
  142. warp/native/bvh.cu +525 -525
  143. warp/native/bvh.h +429 -429
  144. warp/native/clang/clang.cpp +495 -464
  145. warp/native/crt.cpp +31 -31
  146. warp/native/crt.h +334 -334
  147. warp/native/cuda_crt.h +1049 -1049
  148. warp/native/cuda_util.cpp +549 -540
  149. warp/native/cuda_util.h +288 -203
  150. warp/native/cutlass_gemm.cpp +34 -34
  151. warp/native/cutlass_gemm.cu +372 -372
  152. warp/native/error.cpp +66 -66
  153. warp/native/error.h +27 -27
  154. warp/native/fabric.h +228 -228
  155. warp/native/hashgrid.cpp +301 -278
  156. warp/native/hashgrid.cu +78 -77
  157. warp/native/hashgrid.h +227 -227
  158. warp/native/initializer_array.h +32 -32
  159. warp/native/intersect.h +1204 -1204
  160. warp/native/intersect_adj.h +365 -365
  161. warp/native/intersect_tri.h +322 -322
  162. warp/native/marching.cpp +2 -2
  163. warp/native/marching.cu +497 -497
  164. warp/native/marching.h +2 -2
  165. warp/native/mat.h +1498 -1498
  166. warp/native/matnn.h +333 -333
  167. warp/native/mesh.cpp +203 -203
  168. warp/native/mesh.cu +293 -293
  169. warp/native/mesh.h +1887 -1887
  170. warp/native/nanovdb/NanoVDB.h +4782 -4782
  171. warp/native/nanovdb/PNanoVDB.h +2553 -2553
  172. warp/native/nanovdb/PNanoVDBWrite.h +294 -294
  173. warp/native/noise.h +850 -850
  174. warp/native/quat.h +1084 -1084
  175. warp/native/rand.h +299 -299
  176. warp/native/range.h +108 -108
  177. warp/native/reduce.cpp +156 -156
  178. warp/native/reduce.cu +348 -348
  179. warp/native/runlength_encode.cpp +61 -61
  180. warp/native/runlength_encode.cu +46 -46
  181. warp/native/scan.cpp +30 -30
  182. warp/native/scan.cu +36 -36
  183. warp/native/scan.h +7 -7
  184. warp/native/solid_angle.h +442 -442
  185. warp/native/sort.cpp +94 -94
  186. warp/native/sort.cu +97 -97
  187. warp/native/sort.h +14 -14
  188. warp/native/sparse.cpp +337 -337
  189. warp/native/sparse.cu +544 -544
  190. warp/native/spatial.h +630 -630
  191. warp/native/svd.h +562 -562
  192. warp/native/temp_buffer.h +30 -30
  193. warp/native/vec.h +1132 -1132
  194. warp/native/volume.cpp +297 -297
  195. warp/native/volume.cu +32 -32
  196. warp/native/volume.h +538 -538
  197. warp/native/volume_builder.cu +425 -425
  198. warp/native/volume_builder.h +19 -19
  199. warp/native/warp.cpp +1057 -1052
  200. warp/native/warp.cu +2943 -2828
  201. warp/native/warp.h +313 -305
  202. warp/optim/__init__.py +9 -9
  203. warp/optim/adam.py +120 -120
  204. warp/optim/linear.py +1104 -939
  205. warp/optim/sgd.py +104 -92
  206. warp/render/__init__.py +10 -10
  207. warp/render/render_opengl.py +3217 -3204
  208. warp/render/render_usd.py +768 -749
  209. warp/render/utils.py +152 -150
  210. warp/sim/__init__.py +52 -59
  211. warp/sim/articulation.py +685 -685
  212. warp/sim/collide.py +1594 -1590
  213. warp/sim/import_mjcf.py +489 -481
  214. warp/sim/import_snu.py +220 -221
  215. warp/sim/import_urdf.py +536 -516
  216. warp/sim/import_usd.py +887 -881
  217. warp/sim/inertia.py +316 -317
  218. warp/sim/integrator.py +234 -233
  219. warp/sim/integrator_euler.py +1956 -1956
  220. warp/sim/integrator_featherstone.py +1910 -1991
  221. warp/sim/integrator_xpbd.py +3294 -3312
  222. warp/sim/model.py +4473 -4314
  223. warp/sim/particles.py +113 -112
  224. warp/sim/render.py +417 -403
  225. warp/sim/utils.py +413 -410
  226. warp/sparse.py +1227 -1227
  227. warp/stubs.py +2109 -2469
  228. warp/tape.py +1162 -225
  229. warp/tests/__init__.py +1 -1
  230. warp/tests/__main__.py +4 -4
  231. warp/tests/assets/torus.usda +105 -105
  232. warp/tests/aux_test_class_kernel.py +26 -26
  233. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  234. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  235. warp/tests/aux_test_dependent.py +22 -22
  236. warp/tests/aux_test_grad_customs.py +23 -23
  237. warp/tests/aux_test_reference.py +11 -11
  238. warp/tests/aux_test_reference_reference.py +10 -10
  239. warp/tests/aux_test_square.py +17 -17
  240. warp/tests/aux_test_unresolved_func.py +14 -14
  241. warp/tests/aux_test_unresolved_symbol.py +14 -14
  242. warp/tests/disabled_kinematics.py +239 -239
  243. warp/tests/run_coverage_serial.py +31 -31
  244. warp/tests/test_adam.py +157 -157
  245. warp/tests/test_arithmetic.py +1124 -1124
  246. warp/tests/test_array.py +2417 -2326
  247. warp/tests/test_array_reduce.py +150 -150
  248. warp/tests/test_async.py +668 -656
  249. warp/tests/test_atomic.py +141 -141
  250. warp/tests/test_bool.py +204 -149
  251. warp/tests/test_builtins_resolution.py +1292 -1292
  252. warp/tests/test_bvh.py +164 -171
  253. warp/tests/test_closest_point_edge_edge.py +228 -228
  254. warp/tests/test_codegen.py +566 -553
  255. warp/tests/test_compile_consts.py +97 -101
  256. warp/tests/test_conditional.py +246 -246
  257. warp/tests/test_copy.py +232 -215
  258. warp/tests/test_ctypes.py +632 -632
  259. warp/tests/test_dense.py +67 -67
  260. warp/tests/test_devices.py +91 -98
  261. warp/tests/test_dlpack.py +530 -529
  262. warp/tests/test_examples.py +400 -378
  263. warp/tests/test_fabricarray.py +955 -955
  264. warp/tests/test_fast_math.py +62 -54
  265. warp/tests/test_fem.py +1277 -1278
  266. warp/tests/test_fp16.py +130 -130
  267. warp/tests/test_func.py +338 -337
  268. warp/tests/test_generics.py +571 -571
  269. warp/tests/test_grad.py +746 -640
  270. warp/tests/test_grad_customs.py +333 -336
  271. warp/tests/test_hash_grid.py +210 -164
  272. warp/tests/test_import.py +39 -39
  273. warp/tests/test_indexedarray.py +1134 -1134
  274. warp/tests/test_intersect.py +67 -67
  275. warp/tests/test_jax.py +307 -307
  276. warp/tests/test_large.py +167 -164
  277. warp/tests/test_launch.py +354 -354
  278. warp/tests/test_lerp.py +261 -261
  279. warp/tests/test_linear_solvers.py +191 -171
  280. warp/tests/test_lvalue.py +421 -493
  281. warp/tests/test_marching_cubes.py +65 -65
  282. warp/tests/test_mat.py +1801 -1827
  283. warp/tests/test_mat_lite.py +115 -115
  284. warp/tests/test_mat_scalar_ops.py +2907 -2889
  285. warp/tests/test_math.py +126 -193
  286. warp/tests/test_matmul.py +500 -499
  287. warp/tests/test_matmul_lite.py +410 -410
  288. warp/tests/test_mempool.py +188 -190
  289. warp/tests/test_mesh.py +284 -324
  290. warp/tests/test_mesh_query_aabb.py +228 -241
  291. warp/tests/test_mesh_query_point.py +692 -702
  292. warp/tests/test_mesh_query_ray.py +292 -303
  293. warp/tests/test_mlp.py +276 -276
  294. warp/tests/test_model.py +110 -110
  295. warp/tests/test_modules_lite.py +39 -39
  296. warp/tests/test_multigpu.py +163 -163
  297. warp/tests/test_noise.py +248 -248
  298. warp/tests/test_operators.py +250 -250
  299. warp/tests/test_options.py +123 -125
  300. warp/tests/test_peer.py +133 -137
  301. warp/tests/test_pinned.py +78 -78
  302. warp/tests/test_print.py +54 -54
  303. warp/tests/test_quat.py +2086 -2086
  304. warp/tests/test_rand.py +288 -288
  305. warp/tests/test_reload.py +217 -217
  306. warp/tests/test_rounding.py +179 -179
  307. warp/tests/test_runlength_encode.py +190 -190
  308. warp/tests/test_sim_grad.py +243 -0
  309. warp/tests/test_sim_kinematics.py +91 -97
  310. warp/tests/test_smoothstep.py +168 -168
  311. warp/tests/test_snippet.py +305 -266
  312. warp/tests/test_sparse.py +468 -460
  313. warp/tests/test_spatial.py +2148 -2148
  314. warp/tests/test_streams.py +486 -473
  315. warp/tests/test_struct.py +710 -675
  316. warp/tests/test_tape.py +173 -148
  317. warp/tests/test_torch.py +743 -743
  318. warp/tests/test_transient_module.py +87 -87
  319. warp/tests/test_types.py +556 -659
  320. warp/tests/test_utils.py +490 -499
  321. warp/tests/test_vec.py +1264 -1268
  322. warp/tests/test_vec_lite.py +73 -73
  323. warp/tests/test_vec_scalar_ops.py +2099 -2099
  324. warp/tests/test_verify_fp.py +94 -94
  325. warp/tests/test_volume.py +737 -736
  326. warp/tests/test_volume_write.py +255 -265
  327. warp/tests/unittest_serial.py +37 -37
  328. warp/tests/unittest_suites.py +363 -359
  329. warp/tests/unittest_utils.py +603 -578
  330. warp/tests/unused_test_misc.py +71 -71
  331. warp/tests/walkthrough_debug.py +85 -85
  332. warp/thirdparty/appdirs.py +598 -598
  333. warp/thirdparty/dlpack.py +143 -143
  334. warp/thirdparty/unittest_parallel.py +566 -561
  335. warp/torch.py +321 -295
  336. warp/types.py +4504 -4450
  337. warp/utils.py +1008 -821
  338. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
  339. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
  340. warp_lang-1.1.0.dist-info/RECORD +352 -0
  341. warp/examples/assets/cube.usda +0 -42
  342. warp/examples/assets/sphere.usda +0 -56
  343. warp/examples/assets/torus.usda +0 -105
  344. warp_lang-1.0.1.dist-info/RECORD +0 -352
  345. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
  346. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,2099 +1,2099 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import unittest
9
-
10
- import numpy as np
11
-
12
- import warp as wp
13
- from warp.tests.unittest_utils import *
14
-
15
# Initialize the Warp runtime once, at import time, before any test in this
# module creates vector types, arrays, or kernels.
wp.init()
16
-
17
# Signed integer dtypes under test. ``np.byte`` is NumPy's C ``char`` type, which
# aliases ``np.int8`` on common platforms, so that dtype may effectively repeat.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]

# Unsigned counterparts; ``np.ubyte`` likewise typically aliases ``np.uint8``.
np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

np_int_types = np_signed_int_types + np_unsigned_int_types

np_float_types = [np.float16, np.float32, np.float64]

# All scalar dtypes: integer types first, then floating-point types.
np_scalar_types = np_int_types + np_float_types


def randvals(rng, shape, dtype):
    """Draw random test data of ``shape`` with NumPy ``dtype`` from generator ``rng``.

    Floating-point dtypes receive standard-normal samples cast to ``dtype``.
    Integer dtypes receive small values starting at 1: ``{1, 2}`` for the 8-bit
    types and ``{1, 2, 3, 4}`` otherwise (presumably so arithmetic on them in the
    tests stays clear of overflow).
    """
    if dtype in np_float_types:
        return rng.standard_normal(size=shape).astype(dtype)

    # ``high`` is exclusive in ``Generator.integers``.
    is_eight_bit = dtype in [np.int8, np.uint8, np.byte, np.ubyte]
    upper = 3 if is_eight_bit else 5
    return rng.integers(1, high=upper, size=shape, dtype=dtype)
46
-
47
-
48
# Kernels built on demand by ``getkernel``, keyed by "<function name>_<suffix>".
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` wrapping ``func``, building it on first request.

    The kernel is stored in ``kernel_cache`` under ``func.__name__ + "_" + suffix``
    and later calls with the same name and suffix return that same object. Callers
    that define ``func`` inside a closure (e.g. once per dtype) must pass a
    distinguishing ``suffix``. Otherwise the first kernel built under that name is
    returned for every variant.
    """
    cache_key = "_".join((func.__name__, suffix))
    try:
        return kernel_cache[cache_key]
    except KeyError:
        pass

    built = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = built
    return built
56
-
57
-
58
def get_select_kernel(dtype):
    """Return a kernel that copies ``input[index]`` into ``out[0]`` for 1-D arrays of ``dtype``.

    The kernel is obtained through ``getkernel`` with ``dtype.__name__`` as the
    suffix, so each dtype gets its own cached instance.
    """

    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index]

    type_suffix = dtype.__name__
    return getkernel(output_select_kernel_fn, suffix=type_suffix)
67
-
68
-
69
def get_select_kernel2(dtype):
    """Return a kernel that copies ``input[index0, index1]`` of a 2-D array into ``out[0]``.

    As with the 1-D variant, the kernel is cached via ``getkernel`` with
    ``dtype.__name__`` as the suffix.
    """

    def output_select_kernel2_fn(
        input: wp.array(dtype=dtype, ndim=2),
        index0: int,
        index1: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index0, index1]

    type_suffix = dtype.__name__
    return getkernel(output_select_kernel2_fn, suffix=type_suffix)
79
-
80
-
81
def test_arrays(test, device, dtype):
    """Round-trip NumPy data through Warp arrays of 2- to 5-component vectors.

    For each length, a (10, length) block of random data is uploaded into a
    ``wp.array`` whose dtype is ``wp.types.vector(length, dtype)`` and read back
    with ``.numpy()``; the result must match the source data. ``test`` is not
    referenced; comparisons go through ``assert_np_equal``.
    """
    rng = np.random.default_rng(123)
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Draw the host data for lengths 2, 3, 4, 5 in that order from the fixed seed.
    lengths = (2, 3, 4, 5)
    host_data = {n: randvals(rng, (10, n), dtype) for n in lengths}

    vec_types = {n: wp.types.vector(length=n, dtype=wptype) for n in lengths}
    device_arrays = {
        n: wp.array(host_data[n], dtype=vec_types[n], requires_grad=True, device=device)
        for n in lengths
    }
    for n in lengths:
        assert_np_equal(device_arrays[n].numpy(), host_data[n], tol=1.0e-6)

    # NOTE(review): this second pass re-declares the 2/3/4-component vector types
    # and repeats the same round-trip on the same data. Beyond constructing the
    # vector types a second time, it checks nothing the first pass did not.
    for n in (2, 3, 4):
        redeclared = wp.types.vector(length=n, dtype=wptype)
        arr = wp.array(host_data[n], dtype=redeclared, requires_grad=True, device=device)
        assert_np_equal(arr.numpy(), host_data[n], tol=1.0e-6)
116
-
117
-
118
def test_components(test, device, dtype):
    """Exercise Python-side component access on a 3-vector of ``dtype``.

    Covers ``__getitem__`` and ``__setitem__`` with single indices and slices.
    This matters especially for float16, which requires special handling
    internally. ``device`` is not referenced.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def expect(seq, *values):
        # One assertEqual per position: seq[0] == values[0], seq[1] == values[1], ...
        for pos, wanted in enumerate(values):
            test.assertEqual(seq[pos], wanted)

    v = vec3(1, 2, 3)

    # Reads: individual components, then slices.
    expect(v, 1, 2, 3)
    expect(v[:], 1, 2, 3)
    expect(v[1:], 2, 3)
    expect(v[:2], 1, 2)
    expect(v[::2], 1, 3)

    # Writes: individual components, assigned in index order.
    for pos, new_value in enumerate((4, 5, 6)):
        v[pos] = new_value
    expect(v, 4, 5, 6)

    # Writes: slices; untouched components must keep their previous values.
    v[:] = [7, 8, 9]
    expect(v, 7, 8, 9)

    v[1:] = [10, 11]
    expect(v, 7, 10, 11)

    v[:2] = [12, 13]
    expect(v, 12, 13, 11)

    v[::2] = [14, 15]
    expect(v, 14, 13, 15)
178
-
179
-
180
def test_py_arithmetic_ops(test, device, dtype):
    """Check Python-scope arithmetic operators on a 3-vector of ``dtype``.

    Covers unary ``+``/``-``, vector addition and subtraction, and scalar
    multiplication and division from both sides. For integer Warp types, the
    expected values are first passed through the type's ctypes representation
    to simulate fixed-width wrapping.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    wraps = wptype in wp.types.int_types

    def expected(*components):
        # Cast to the correct integer type to simulate wrapping (e.g. negating an
        # unsigned vector); floating-point expectations are returned unchanged.
        if wraps:
            return tuple(wptype._type_(c).value for c in components)
        return components

    vec_cls = wp.vec(3, wptype)

    # Unary operators and vector +/- vector.
    mixed = vec_cls(1, -2, 3)
    test.assertSequenceEqual(+mixed, expected(1, -2, 3))
    test.assertSequenceEqual(-mixed, expected(-1, 2, -3))
    test.assertSequenceEqual(mixed + vec_cls(5, 5, 5), expected(6, 3, 8))
    test.assertSequenceEqual(mixed - vec_cls(5, 5, 5), expected(-4, -7, -2))

    # Scalar * and / with the scalar on either side.
    evens = vec_cls(2, 4, 6)
    test.assertSequenceEqual(evens * wptype(2), expected(4, 8, 12))
    test.assertSequenceEqual(wptype(2) * evens, expected(4, 8, 12))
    test.assertSequenceEqual(evens / wptype(2), expected(1, 2, 3))
    test.assertSequenceEqual(wptype(24) / evens, expected(12, 6, 4))
203
-
204
-
205
def test_constructors(test, device, dtype, register_kernels=False):
    """Test the fill and per-component constructors of 2- to 5-component vectors.

    Two kernels are built for ``dtype``:

    * ``check_scalar_constructor`` builds every vector as ``vecN(input[0])``. It
      writes each component, times 2, to its own scalar output, so every scalar
      output has gradient 2 with respect to ``input[0]``.
    * ``check_vector_constructors`` builds the vectors component-wise from
      ``input[0:14]`` in order. Each scalar output (component times 2) therefore
      has gradient 2 on exactly one input element.

    Forward values are always checked; gradients are checked only for
    floating-point dtypes. With ``register_kernels=True``, the kernels are
    created and cached, and the function returns without launching anything.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_scalar_constructor(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = vec2(input[0])
        v3result = vec3(input[0])
        v4result = vec4(input[0])
        v5result = vec5(input[0])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply outputs by 2 so we've got something to backpropagate
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    def check_vector_constructors(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = vec2(input[0], input[1])
        v3result = vec3(input[2], input[3], input[4])
        v4result = vec4(input[5], input[6], input[7], input[8])
        v5result = vec5(input[9], input[10], input[11], input[12], input[13])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply the output by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    vec_kernel = getkernel(check_vector_constructors, suffix=dtype.__name__)
    kernel = getkernel(check_scalar_constructor, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    component_count = sum(lengths)  # 14 scalar outputs: v20, v21, v30, ..., v54
    is_float = dtype in np_float_types

    # ---- Fill constructor: vecN(x) ----
    src = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    vectors = [wp.zeros(1, dtype=vt, device=device) for vt in (vec2, vec3, vec4, vec5)]
    components = [
        wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(component_count)
    ]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[src],
            outputs=[*vectors, *components],
            device=device,
        )

    if is_float:
        # Every component was filled from src[0], so each backward pass sees d/dx (2x) = 2.
        for comp in components:
            tape.backward(loss=comp)
            test.assertEqual(tape.gradients[src].numpy()[0], 2.0)
            tape.zero()

    val = src.numpy()[0]
    for n, vec_arr in zip(lengths, vectors):
        assert_np_equal(vec_arr.numpy()[0], np.array([val] * n), tol=1.0e-6)
    for comp in components:
        assert_np_equal(comp.numpy()[0], 2 * val, tol=1.0e-6)

    # ---- Component-wise constructor: vecN(a, b, ...) ----
    src = wp.array(randvals(rng, [component_count], dtype), requires_grad=True, device=device)
    tape = wp.Tape()
    with tape:
        wp.launch(
            vec_kernel,
            dim=1,
            inputs=[src],
            outputs=[*vectors, *components],
            device=device,
        )

    if is_float:
        # Scalar output i depends only on src[i].
        for i, comp in enumerate(components):
            tape.backward(loss=comp)
            grad = tape.gradients[src].numpy()
            expected_grad = np.zeros_like(grad)
            expected_grad[i] = 2
            assert_np_equal(grad, expected_grad, tol=tol)
            tape.zero()

    # The concatenated vector components consume src in order.
    flat_index = 0
    for n, vec_arr in zip(lengths, vectors):
        for c in range(n):
            assert_np_equal(vec_arr.numpy()[0, c], src.numpy()[flat_index], tol=tol)
            flat_index += 1

    for i, comp in enumerate(components):
        assert_np_equal(comp.numpy()[0], 2 * src.numpy()[i], tol=tol)
447
-
448
-
449
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Test anonymous ``wp.vector`` construction inside kernels, including gradients.

    ``check_scalar_init`` builds vectors of length 2..5, each from a single scalar
    (``wp.vector(x, length=n)``). ``check_component_init`` builds the same vectors
    from individual components. Both kernels write ``2 * component`` into one flat
    output of 2 + 3 + 4 + 5 entries.

    For float dtypes, each output entry is used as the loss on its own. The input
    gradient is then compared with the expected one-hot value of 2.

    When ``register_kernels`` is true, only the kernels are created and nothing is launched.
    """
    rng = np.random.default_rng(123)

    # Gradient-check tolerance per float dtype; any other dtype falls back to 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Kernel: broadcast input[k] into a vector of length k + 2, then write 2 * each component.
    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], length=2)
        v3result = wp.vector(input[1], length=3)
        v4result = wp.vector(input[2], length=4)
        v5result = wp.vector(input[3], length=5)

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    # Kernel: build vectors of length 2..5 from consecutive input entries, then write 2 * each component.
    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], input[1])
        v3result = wp.vector(input[2], input[3], input[4])
        v4result = wp.vector(input[5], input[6], input[7], input[8])
        v5result = wp.vector(input[9], input[10], input[11], input[12], input[13])

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    vec_lengths = (2, 3, 4, 5)
    total = sum(vec_lengths)

    def check_select_grads(kernel, src, dst, grad_index):
        # Re-run `kernel` once per output entry. Use dst[i] (copied to `out` by the
        # select kernel) as the loss, and check that only src[grad_index[i]] gets gradient 2.
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(total):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[src], outputs=[dst], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[dst, i], outputs=[out], device=device)

            tape.backward(loss=out)
            expected = np.zeros_like(src.numpy())
            expected[grad_index[i]] = 2

            assert_np_equal(tape.gradients[src].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # Scalar-broadcast constructor: one random scalar seeds each vector.
    seeds = wp.array(randvals(rng, [len(vec_lengths)], dtype), requires_grad=True, device=device)
    flat = wp.zeros(total, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[seeds], outputs=[flat], device=device)

    start = 0
    for seed_idx, n in enumerate(vec_lengths):
        expected_vals = 2 * np.array([seeds.numpy()[seed_idx]] * n)
        assert_np_equal(flat.numpy()[start : start + n], expected_vals, tol=1.0e-6)
        start += n

    if dtype in np_float_types:
        # Map each output slot to the scalar that seeded its vector (slots 0-1 -> 0, 2-4 -> 1, ...).
        owners = np.repeat(np.arange(len(vec_lengths)), vec_lengths)
        check_select_grads(scalar_kernel, seeds, flat, owners)

    # Per-component constructor: output slot i comes from input slot i.
    components = wp.array(randvals(rng, [total], dtype), requires_grad=True, device=device)
    flat = wp.zeros(total, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[components], outputs=[flat], device=device)

    assert_np_equal(flat.numpy(), 2 * components.numpy(), tol=1.0e-6)

    if dtype in np_float_types:
        check_select_grads(component_kernel, components, flat, range(total))
570
-
571
-
572
def test_indexing(test, device, dtype, register_kernels=False):
    """Test reading individual components of vectors of length 2..5 inside a kernel.

    Each component is doubled and stored in its own scalar output. The forward
    values are compared with the inputs. For float dtypes, every output is
    backpropagated in turn, and the gradient must be 2 on exactly the matching
    input component.
    """
    rng = np.random.default_rng(123)

    # Comparison tolerance per float dtype; other dtypes compare with tolerance 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # Kernel: write 2 * v<n>[0][k] into output v<n><k> for every component.
    def check_indexing(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2[0][0]
        v21[0] = wptype(2) * v2[0][1]

        v30[0] = wptype(2) * v3[0][0]
        v31[0] = wptype(2) * v3[0][1]
        v32[0] = wptype(2) * v3[0][2]

        v40[0] = wptype(2) * v4[0][0]
        v41[0] = wptype(2) * v4[0][1]
        v42[0] = wptype(2) * v4[0][2]
        v43[0] = wptype(2) * v4[0][3]

        v50[0] = wptype(2) * v5[0][0]
        v51[0] = wptype(2) * v5[0][1]
        v52[0] = wptype(2) * v5[0][2]
        v53[0] = wptype(2) * v5[0][3]
        v54[0] = wptype(2) * v5[0][4]

    kernel = getkernel(check_indexing, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    # One single-element array per vector length, drawn from the RNG in length order.
    vec_arrays = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vec_type, requires_grad=True, device=device)
        for n, vec_type in zip(lengths, (vec2, vec3, vec4, vec5))
    ]
    # One scalar output per component, in kernel-parameter order v20, v21, v30, ..., v54.
    scalar_outs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=vec_arrays,
            outputs=scalar_outs,
            device=device,
        )

    if dtype in np_float_types:
        for i, loss in enumerate(scalar_outs):
            tape.backward(loss=loss)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vec_arrays])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = 2
            assert_np_equal(allgrads, expected_grads, tol=tol)
            tape.zero()

    # Flattened input components line up one-to-one with the scalar outputs.
    flat_inputs = np.concatenate([v.numpy()[0] for v in vec_arrays])
    for out, component in zip(scalar_outs, flat_inputs):
        assert_np_equal(out.numpy()[0], 2.0 * component, tol=tol)
683
-
684
-
685
def test_equality(test, device, dtype, register_kernels=False):
    """Test ``==`` / ``!=`` on vectors of length 2..5 inside a kernel.

    For each length, a base vector must equal itself and differ from every
    variant. Each vec3/vec4/vec5 variant negates exactly one component; the vec2
    variants use hand-picked values. Failures are reported by
    ``wp.expect_eq`` / ``wp.expect_neq`` inside the kernel.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # Kernel: v<n>0 is the base vector; every other v<n>k must compare unequal to it.
    def check_equality(
        v20: wp.array(dtype=vec2),
        v21: wp.array(dtype=vec2),
        v22: wp.array(dtype=vec2),
        v30: wp.array(dtype=vec3),
        v31: wp.array(dtype=vec3),
        v32: wp.array(dtype=vec3),
        v33: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v41: wp.array(dtype=vec4),
        v42: wp.array(dtype=vec4),
        v43: wp.array(dtype=vec4),
        v44: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
        v51: wp.array(dtype=vec5),
        v52: wp.array(dtype=vec5),
        v53: wp.array(dtype=vec5),
        v54: wp.array(dtype=vec5),
        v55: wp.array(dtype=vec5),
    ):
        wp.expect_eq(v20[0], v20[0])
        wp.expect_neq(v21[0], v20[0])
        wp.expect_neq(v22[0], v20[0])

        wp.expect_eq(v30[0], v30[0])
        wp.expect_neq(v31[0], v30[0])
        wp.expect_neq(v32[0], v30[0])
        wp.expect_neq(v33[0], v30[0])

        wp.expect_eq(v40[0], v40[0])
        wp.expect_neq(v41[0], v40[0])
        wp.expect_neq(v42[0], v40[0])
        wp.expect_neq(v43[0], v40[0])
        wp.expect_neq(v44[0], v40[0])

        wp.expect_eq(v50[0], v50[0])
        wp.expect_neq(v51[0], v50[0])
        wp.expect_neq(v52[0], v50[0])
        wp.expect_neq(v53[0], v50[0])
        wp.expect_neq(v54[0], v50[0])
        wp.expect_neq(v55[0], v50[0])

    kernel = getkernel(check_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    def make(values, vec_type):
        return wp.array(values, dtype=vec_type, requires_grad=True, device=device)

    def base_and_negations(vec_type, length):
        # [1, 2, ..., length], then one copy per component with that component negated.
        base = [float(k) for k in range(1, length + 1)]
        variants = [make(base, vec_type)]
        for k in range(length):
            flipped = list(base)
            flipped[k] = -flipped[k]
            variants.append(make(flipped, vec_type))
        return variants

    # vec2 uses hand-picked unequal values rather than sign flips.
    vec2_arrays = [make([1.0, 2.0], vec2), make([1.0, 3.0], vec2), make([3.0, 2.0], vec2)]
    vec3_arrays = base_and_negations(vec3, 3)
    vec4_arrays = base_and_negations(vec4, 4)
    vec5_arrays = base_and_negations(vec5, 5)

    wp.launch(
        kernel,
        dim=1,
        inputs=vec2_arrays + vec3_arrays + vec4_arrays + vec5_arrays,
        outputs=[],
        device=device,
    )
786
-
787
-
788
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Test ``scalar * vector`` for vector lengths 2..5, including gradients.

    The kernel multiplies ``s[0]`` by one vector of each length and writes
    ``2 * component`` of each product to its own scalar output. Forward values are
    checked for every component.

    For float dtypes, every output (vec5 included) is used as the loss in turn.
    The gradient w.r.t. ``s`` must be ``2 * v[k]``. The gradient w.r.t. the
    vectors must be ``2 * s`` on the matching component only.
    """
    rng = np.random.default_rng(123)

    # Comparison tolerance per float dtype; other dtypes compare with tolerance 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # Kernel: v<n><k> = 2 * (s[0] * v<n>[0])[k] for every component.
    def check_mul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = s[0] * v2[0]
        v3result = s[0] * v3[0]
        v4result = s[0] * v4[0]
        v5result = s[0] * v5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    # v2, v3, v4, v5: drawn from the RNG after `s`, in length order.
    vecs = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vec_type, requires_grad=True, device=device)
        for n, vec_type in zip(lengths, (vec2, vec3, vec4, vec5))
    ]
    # One scalar output per component, in kernel-parameter order v20, v21, v30, ..., v54.
    outs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, *vecs],
            outputs=outs,
            device=device,
        )

    s_val = s.numpy()[0]
    incmps = np.concatenate([v.numpy()[0] for v in vecs])

    for i, out in enumerate(outs):
        # vec2 outputs are checked at the base tolerance; longer vectors get 10x headroom.
        out_tol = tol if i < lengths[0] else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * s_val * incmps[i], tol=out_tol)

    if dtype in np_float_types:
        # Backpropagate every output, vec5 ones included, and check gradients w.r.t.
        # `s` and all four input vectors.
        for i, loss in enumerate(outs):
            tape.backward(loss=loss)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vecs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s_val * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
918
-
919
-
920
def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=False):
    """Test ``vector * scalar`` (scalar on the right) for vector lengths 2..5, including gradients.

    The kernel multiplies one vector of each length by ``s[0]`` and writes
    ``2 * component`` of each product to its own scalar output. Forward values are
    checked for every component.

    For float dtypes, every output (vec5 included) is used as the loss in turn.
    The gradient w.r.t. ``s`` must be ``2 * v[k]``. The gradient w.r.t. the
    vectors must be ``2 * s`` on the matching component only.
    """
    rng = np.random.default_rng(123)

    # Comparison tolerance per float dtype; other dtypes compare with tolerance 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # Kernel: v<n><k> = 2 * (v<n>[0] * s[0])[k] for every component.
    def check_rightmul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] * s[0]
        v3result = v3[0] * s[0]
        v4result = v4[0] * s[0]
        v5result = v5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_rightmul, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    # v2, v3, v4, v5: drawn from the RNG after `s`, in length order.
    vecs = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vec_type, requires_grad=True, device=device)
        for n, vec_type in zip(lengths, (vec2, vec3, vec4, vec5))
    ]
    # One scalar output per component, in kernel-parameter order v20, v21, v30, ..., v54.
    outs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, *vecs],
            outputs=outs,
            device=device,
        )

    s_val = s.numpy()[0]
    incmps = np.concatenate([v.numpy()[0] for v in vecs])

    for i, out in enumerate(outs):
        # vec2 outputs are checked at the base tolerance; longer vectors get 10x headroom.
        out_tol = tol if i < lengths[0] else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * s_val * incmps[i], tol=out_tol)

    if dtype in np_float_types:
        # Backpropagate every output, vec5 ones included, and check gradients w.r.t.
        # `s` and all four input vectors.
        for i, loss in enumerate(outs):
            tape.backward(loss=loss)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vecs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s_val * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
1050
-
1051
-
1052
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Test ``wp.cw_mul`` (component-wise product) for vector lengths 2..5, including gradients.

    Each output holds ``2 * s[k] * v[k]`` for one component. For float dtypes,
    every output is used as the loss in turn. The gradient w.r.t. the ``s``
    vectors must be ``2 * v[k]``, and the gradient w.r.t. the ``v`` vectors must
    be ``2 * s[k]``, each on the matching component only.
    """
    rng = np.random.default_rng(123)

    # Comparison tolerance per float dtype; other dtypes compare with tolerance 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # Kernel: v<n><k> = 2 * cw_mul(s<n>[0], v<n>[0])[k] for every component.
    def check_cw_mul(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_mul(s2[0], v2[0])
        v3result = wp.cw_mul(s3[0], v3[0])
        v4result = wp.cw_mul(s4[0], v4[0])
        v5result = wp.cw_mul(s5[0], v5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # One single-element array per vector length, drawn from the RNG in length order.
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in zip(lengths, vec_types)
        ]

    scales = random_vectors()  # s2, s3, s4, s5 (drawn first)
    vecs = random_vectors()  # v2, v3, v4, v5
    # One scalar output per component, in kernel-parameter order v20, v21, v30, ..., v54.
    outs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[*scales, *vecs],
            outputs=outs,
            device=device,
        )

    incmps = np.concatenate([v.numpy()[0] for v in vecs])
    scmps = np.concatenate([v.numpy()[0] for v in scales])

    for out, s_cmp, v_cmp in zip(outs, scmps, incmps):
        assert_np_equal(out.numpy()[0], 2 * s_cmp * v_cmp, tol=10 * tol)

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outs):
        tape.backward(loss=loss)

        sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in scales])
        expected_grads = np.zeros_like(sgrads)
        expected_grads[i] = incmps[i] * 2
        assert_np_equal(sgrads, expected_grads, tol=10 * tol)

        allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vecs])
        expected_grads = np.zeros_like(allgrads)
        expected_grads[i] = scmps[i] * 2
        assert_np_equal(allgrads, expected_grads, tol=10 * tol)

        tape.zero()
1195
-
1196
-
1197
- def test_scalar_division(test, device, dtype, register_kernels=False):
1198
- rng = np.random.default_rng(123)
1199
-
1200
- tol = {
1201
- np.float16: 5.0e-3,
1202
- np.float32: 1.0e-6,
1203
- np.float64: 1.0e-8,
1204
- }.get(dtype, 0)
1205
-
1206
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1207
- vec2 = wp.types.vector(length=2, dtype=wptype)
1208
- vec3 = wp.types.vector(length=3, dtype=wptype)
1209
- vec4 = wp.types.vector(length=4, dtype=wptype)
1210
- vec5 = wp.types.vector(length=5, dtype=wptype)
1211
-
1212
- def check_div(
1213
- s: wp.array(dtype=wptype),
1214
- v2: wp.array(dtype=vec2),
1215
- v3: wp.array(dtype=vec3),
1216
- v4: wp.array(dtype=vec4),
1217
- v5: wp.array(dtype=vec5),
1218
- v20: wp.array(dtype=wptype),
1219
- v21: wp.array(dtype=wptype),
1220
- v30: wp.array(dtype=wptype),
1221
- v31: wp.array(dtype=wptype),
1222
- v32: wp.array(dtype=wptype),
1223
- v40: wp.array(dtype=wptype),
1224
- v41: wp.array(dtype=wptype),
1225
- v42: wp.array(dtype=wptype),
1226
- v43: wp.array(dtype=wptype),
1227
- v50: wp.array(dtype=wptype),
1228
- v51: wp.array(dtype=wptype),
1229
- v52: wp.array(dtype=wptype),
1230
- v53: wp.array(dtype=wptype),
1231
- v54: wp.array(dtype=wptype),
1232
- ):
1233
- v2result = v2[0] / s[0]
1234
- v3result = v3[0] / s[0]
1235
- v4result = v4[0] / s[0]
1236
- v5result = v5[0] / s[0]
1237
-
1238
- v20[0] = wptype(2) * v2result[0]
1239
- v21[0] = wptype(2) * v2result[1]
1240
-
1241
- v30[0] = wptype(2) * v3result[0]
1242
- v31[0] = wptype(2) * v3result[1]
1243
- v32[0] = wptype(2) * v3result[2]
1244
-
1245
- v40[0] = wptype(2) * v4result[0]
1246
- v41[0] = wptype(2) * v4result[1]
1247
- v42[0] = wptype(2) * v4result[2]
1248
- v43[0] = wptype(2) * v4result[3]
1249
-
1250
- v50[0] = wptype(2) * v5result[0]
1251
- v51[0] = wptype(2) * v5result[1]
1252
- v52[0] = wptype(2) * v5result[2]
1253
- v53[0] = wptype(2) * v5result[3]
1254
- v54[0] = wptype(2) * v5result[4]
1255
-
1256
- kernel = getkernel(check_div, suffix=dtype.__name__)
1257
-
1258
- if register_kernels:
1259
- return
1260
-
1261
- s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
1262
- v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1263
- v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1264
- v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1265
- v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1266
- v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1267
- v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1268
- v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1269
- v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1270
- v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1271
- v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1272
- v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1273
- v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1274
- v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1275
- v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1276
- v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1277
- v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1278
- v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1279
- v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1280
- tape = wp.Tape()
1281
- with tape:
1282
- wp.launch(
1283
- kernel,
1284
- dim=1,
1285
- inputs=[
1286
- s,
1287
- v2,
1288
- v3,
1289
- v4,
1290
- v5,
1291
- ],
1292
- outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
1293
- device=device,
1294
- )
1295
-
1296
- if dtype in np_int_types:
1297
- assert_np_equal(v20.numpy()[0], 2 * (v2.numpy()[0, 0] // (s.numpy()[0])), tol=tol)
1298
- assert_np_equal(v21.numpy()[0], 2 * (v2.numpy()[0, 1] // (s.numpy()[0])), tol=tol)
1299
-
1300
- assert_np_equal(v30.numpy()[0], 2 * (v3.numpy()[0, 0] // (s.numpy()[0])), tol=10 * tol)
1301
- assert_np_equal(v31.numpy()[0], 2 * (v3.numpy()[0, 1] // (s.numpy()[0])), tol=10 * tol)
1302
- assert_np_equal(v32.numpy()[0], 2 * (v3.numpy()[0, 2] // (s.numpy()[0])), tol=10 * tol)
1303
-
1304
- assert_np_equal(v40.numpy()[0], 2 * (v4.numpy()[0, 0] // (s.numpy()[0])), tol=10 * tol)
1305
- assert_np_equal(v41.numpy()[0], 2 * (v4.numpy()[0, 1] // (s.numpy()[0])), tol=10 * tol)
1306
- assert_np_equal(v42.numpy()[0], 2 * (v4.numpy()[0, 2] // (s.numpy()[0])), tol=10 * tol)
1307
- assert_np_equal(v43.numpy()[0], 2 * (v4.numpy()[0, 3] // (s.numpy()[0])), tol=10 * tol)
1308
-
1309
- assert_np_equal(v50.numpy()[0], 2 * (v5.numpy()[0, 0] // (s.numpy()[0])), tol=10 * tol)
1310
- assert_np_equal(v51.numpy()[0], 2 * (v5.numpy()[0, 1] // (s.numpy()[0])), tol=10 * tol)
1311
- assert_np_equal(v52.numpy()[0], 2 * (v5.numpy()[0, 2] // (s.numpy()[0])), tol=10 * tol)
1312
- assert_np_equal(v53.numpy()[0], 2 * (v5.numpy()[0, 3] // (s.numpy()[0])), tol=10 * tol)
1313
- assert_np_equal(v54.numpy()[0], 2 * (v5.numpy()[0, 4] // (s.numpy()[0])), tol=10 * tol)
1314
-
1315
- else:
1316
- assert_np_equal(v20.numpy()[0], 2 * v2.numpy()[0, 0] / (s.numpy()[0]), tol=tol)
1317
- assert_np_equal(v21.numpy()[0], 2 * v2.numpy()[0, 1] / (s.numpy()[0]), tol=tol)
1318
-
1319
- assert_np_equal(v30.numpy()[0], 2 * v3.numpy()[0, 0] / (s.numpy()[0]), tol=10 * tol)
1320
- assert_np_equal(v31.numpy()[0], 2 * v3.numpy()[0, 1] / (s.numpy()[0]), tol=10 * tol)
1321
- assert_np_equal(v32.numpy()[0], 2 * v3.numpy()[0, 2] / (s.numpy()[0]), tol=10 * tol)
1322
-
1323
- assert_np_equal(v40.numpy()[0], 2 * v4.numpy()[0, 0] / (s.numpy()[0]), tol=10 * tol)
1324
- assert_np_equal(v41.numpy()[0], 2 * v4.numpy()[0, 1] / (s.numpy()[0]), tol=10 * tol)
1325
- assert_np_equal(v42.numpy()[0], 2 * v4.numpy()[0, 2] / (s.numpy()[0]), tol=10 * tol)
1326
- assert_np_equal(v43.numpy()[0], 2 * v4.numpy()[0, 3] / (s.numpy()[0]), tol=10 * tol)
1327
-
1328
- assert_np_equal(v50.numpy()[0], 2 * v5.numpy()[0, 0] / (s.numpy()[0]), tol=10 * tol)
1329
- assert_np_equal(v51.numpy()[0], 2 * v5.numpy()[0, 1] / (s.numpy()[0]), tol=10 * tol)
1330
- assert_np_equal(v52.numpy()[0], 2 * v5.numpy()[0, 2] / (s.numpy()[0]), tol=10 * tol)
1331
- assert_np_equal(v53.numpy()[0], 2 * v5.numpy()[0, 3] / (s.numpy()[0]), tol=10 * tol)
1332
- assert_np_equal(v54.numpy()[0], 2 * v5.numpy()[0, 4] / (s.numpy()[0]), tol=10 * tol)
1333
-
1334
- incmps = np.concatenate([v.numpy()[0] for v in [v2, v3, v4, v5]])
1335
-
1336
- if dtype in np_float_types:
1337
- for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
1338
- tape.backward(loss=l)
1339
- sgrad = tape.gradients[s].numpy()[0]
1340
-
1341
- # d/ds v/s = -v/s^2
1342
- assert_np_equal(sgrad, -2 * incmps[i] / (s.numpy()[0] * s.numpy()[0]), tol=10 * tol)
1343
-
1344
- allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
1345
- expected_grads = np.zeros_like(allgrads)
1346
- expected_grads[i] = 2 / s.numpy()[0]
1347
-
1348
- # d/dv v/s = 1/s
1349
- assert_np_equal(allgrads, expected_grads, tol=tol)
1350
- tape.zero()
1351
-
1352
-
1353
- def test_cw_division(test, device, dtype, register_kernels=False):
1354
- rng = np.random.default_rng(123)
1355
-
1356
- tol = {
1357
- np.float16: 1.0e-2,
1358
- np.float32: 1.0e-6,
1359
- np.float64: 1.0e-8,
1360
- }.get(dtype, 0)
1361
-
1362
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1363
- vec2 = wp.types.vector(length=2, dtype=wptype)
1364
- vec3 = wp.types.vector(length=3, dtype=wptype)
1365
- vec4 = wp.types.vector(length=4, dtype=wptype)
1366
- vec5 = wp.types.vector(length=5, dtype=wptype)
1367
-
1368
- def check_cw_div(
1369
- s2: wp.array(dtype=vec2),
1370
- s3: wp.array(dtype=vec3),
1371
- s4: wp.array(dtype=vec4),
1372
- s5: wp.array(dtype=vec5),
1373
- v2: wp.array(dtype=vec2),
1374
- v3: wp.array(dtype=vec3),
1375
- v4: wp.array(dtype=vec4),
1376
- v5: wp.array(dtype=vec5),
1377
- v20: wp.array(dtype=wptype),
1378
- v21: wp.array(dtype=wptype),
1379
- v30: wp.array(dtype=wptype),
1380
- v31: wp.array(dtype=wptype),
1381
- v32: wp.array(dtype=wptype),
1382
- v40: wp.array(dtype=wptype),
1383
- v41: wp.array(dtype=wptype),
1384
- v42: wp.array(dtype=wptype),
1385
- v43: wp.array(dtype=wptype),
1386
- v50: wp.array(dtype=wptype),
1387
- v51: wp.array(dtype=wptype),
1388
- v52: wp.array(dtype=wptype),
1389
- v53: wp.array(dtype=wptype),
1390
- v54: wp.array(dtype=wptype),
1391
- ):
1392
- v2result = wp.cw_div(v2[0], s2[0])
1393
- v3result = wp.cw_div(v3[0], s3[0])
1394
- v4result = wp.cw_div(v4[0], s4[0])
1395
- v5result = wp.cw_div(v5[0], s5[0])
1396
-
1397
- v20[0] = wptype(2) * v2result[0]
1398
- v21[0] = wptype(2) * v2result[1]
1399
-
1400
- v30[0] = wptype(2) * v3result[0]
1401
- v31[0] = wptype(2) * v3result[1]
1402
- v32[0] = wptype(2) * v3result[2]
1403
-
1404
- v40[0] = wptype(2) * v4result[0]
1405
- v41[0] = wptype(2) * v4result[1]
1406
- v42[0] = wptype(2) * v4result[2]
1407
- v43[0] = wptype(2) * v4result[3]
1408
-
1409
- v50[0] = wptype(2) * v5result[0]
1410
- v51[0] = wptype(2) * v5result[1]
1411
- v52[0] = wptype(2) * v5result[2]
1412
- v53[0] = wptype(2) * v5result[3]
1413
- v54[0] = wptype(2) * v5result[4]
1414
-
1415
- kernel = getkernel(check_cw_div, suffix=dtype.__name__)
1416
-
1417
- if register_kernels:
1418
- return
1419
-
1420
- s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1421
- s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1422
- s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1423
- s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1424
- v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1425
- v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1426
- v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1427
- v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1428
- v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1429
- v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1430
- v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1431
- v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1432
- v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1433
- v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1434
- v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1435
- v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1436
- v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1437
- v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1438
- v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1439
- v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1440
- v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1441
- v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1442
- tape = wp.Tape()
1443
- with tape:
1444
- wp.launch(
1445
- kernel,
1446
- dim=1,
1447
- inputs=[
1448
- s2,
1449
- s3,
1450
- s4,
1451
- s5,
1452
- v2,
1453
- v3,
1454
- v4,
1455
- v5,
1456
- ],
1457
- outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
1458
- device=device,
1459
- )
1460
-
1461
- if dtype in np_int_types:
1462
- assert_np_equal(v20.numpy()[0], 2 * (v2.numpy()[0, 0] // s2.numpy()[0, 0]), tol=tol)
1463
- assert_np_equal(v21.numpy()[0], 2 * (v2.numpy()[0, 1] // s2.numpy()[0, 1]), tol=tol)
1464
-
1465
- assert_np_equal(v30.numpy()[0], 2 * (v3.numpy()[0, 0] // s3.numpy()[0, 0]), tol=tol)
1466
- assert_np_equal(v31.numpy()[0], 2 * (v3.numpy()[0, 1] // s3.numpy()[0, 1]), tol=tol)
1467
- assert_np_equal(v32.numpy()[0], 2 * (v3.numpy()[0, 2] // s3.numpy()[0, 2]), tol=tol)
1468
-
1469
- assert_np_equal(v40.numpy()[0], 2 * (v4.numpy()[0, 0] // s4.numpy()[0, 0]), tol=tol)
1470
- assert_np_equal(v41.numpy()[0], 2 * (v4.numpy()[0, 1] // s4.numpy()[0, 1]), tol=tol)
1471
- assert_np_equal(v42.numpy()[0], 2 * (v4.numpy()[0, 2] // s4.numpy()[0, 2]), tol=tol)
1472
- assert_np_equal(v43.numpy()[0], 2 * (v4.numpy()[0, 3] // s4.numpy()[0, 3]), tol=tol)
1473
-
1474
- assert_np_equal(v50.numpy()[0], 2 * (v5.numpy()[0, 0] // s5.numpy()[0, 0]), tol=tol)
1475
- assert_np_equal(v51.numpy()[0], 2 * (v5.numpy()[0, 1] // s5.numpy()[0, 1]), tol=tol)
1476
- assert_np_equal(v52.numpy()[0], 2 * (v5.numpy()[0, 2] // s5.numpy()[0, 2]), tol=tol)
1477
- assert_np_equal(v53.numpy()[0], 2 * (v5.numpy()[0, 3] // s5.numpy()[0, 3]), tol=tol)
1478
- assert_np_equal(v54.numpy()[0], 2 * (v5.numpy()[0, 4] // s5.numpy()[0, 4]), tol=tol)
1479
- else:
1480
- assert_np_equal(v20.numpy()[0], 2 * v2.numpy()[0, 0] / s2.numpy()[0, 0], tol=tol)
1481
- assert_np_equal(v21.numpy()[0], 2 * v2.numpy()[0, 1] / s2.numpy()[0, 1], tol=tol)
1482
-
1483
- assert_np_equal(v30.numpy()[0], 2 * v3.numpy()[0, 0] / s3.numpy()[0, 0], tol=tol)
1484
- assert_np_equal(v31.numpy()[0], 2 * v3.numpy()[0, 1] / s3.numpy()[0, 1], tol=tol)
1485
- assert_np_equal(v32.numpy()[0], 2 * v3.numpy()[0, 2] / s3.numpy()[0, 2], tol=tol)
1486
-
1487
- assert_np_equal(v40.numpy()[0], 2 * v4.numpy()[0, 0] / s4.numpy()[0, 0], tol=tol)
1488
- assert_np_equal(v41.numpy()[0], 2 * v4.numpy()[0, 1] / s4.numpy()[0, 1], tol=tol)
1489
- assert_np_equal(v42.numpy()[0], 2 * v4.numpy()[0, 2] / s4.numpy()[0, 2], tol=tol)
1490
- assert_np_equal(v43.numpy()[0], 2 * v4.numpy()[0, 3] / s4.numpy()[0, 3], tol=tol)
1491
-
1492
- assert_np_equal(v50.numpy()[0], 2 * v5.numpy()[0, 0] / s5.numpy()[0, 0], tol=tol)
1493
- assert_np_equal(v51.numpy()[0], 2 * v5.numpy()[0, 1] / s5.numpy()[0, 1], tol=tol)
1494
- assert_np_equal(v52.numpy()[0], 2 * v5.numpy()[0, 2] / s5.numpy()[0, 2], tol=tol)
1495
- assert_np_equal(v53.numpy()[0], 2 * v5.numpy()[0, 3] / s5.numpy()[0, 3], tol=tol)
1496
- assert_np_equal(v54.numpy()[0], 2 * v5.numpy()[0, 4] / s5.numpy()[0, 4], tol=tol)
1497
-
1498
- if dtype in np_float_types:
1499
- incmps = np.concatenate([v.numpy()[0] for v in [v2, v3, v4, v5]])
1500
- scmps = np.concatenate([v.numpy()[0] for v in [s2, s3, s4, s5]])
1501
-
1502
- for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
1503
- tape.backward(loss=l)
1504
- sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [s2, s3, s4, s5]])
1505
- expected_grads = np.zeros_like(sgrads)
1506
-
1507
- # d/ds v/s = -v/s^2
1508
- expected_grads[i] = -incmps[i] * 2 / (scmps[i] * scmps[i])
1509
- assert_np_equal(sgrads, expected_grads, tol=20 * tol)
1510
-
1511
- allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
1512
- expected_grads = np.zeros_like(allgrads)
1513
-
1514
- # d/dv v/s = 1/s
1515
- expected_grads[i] = 2 / scmps[i]
1516
- assert_np_equal(allgrads, expected_grads, tol=tol)
1517
-
1518
- tape.zero()
1519
-
1520
-
1521
- def test_addition(test, device, dtype, register_kernels=False):
1522
- rng = np.random.default_rng(123)
1523
-
1524
- tol = {
1525
- np.float16: 5.0e-3,
1526
- np.float32: 1.0e-6,
1527
- np.float64: 1.0e-8,
1528
- }.get(dtype, 0)
1529
-
1530
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1531
- vec2 = wp.types.vector(length=2, dtype=wptype)
1532
- vec3 = wp.types.vector(length=3, dtype=wptype)
1533
- vec4 = wp.types.vector(length=4, dtype=wptype)
1534
- vec5 = wp.types.vector(length=5, dtype=wptype)
1535
-
1536
- def check_add(
1537
- s2: wp.array(dtype=vec2),
1538
- s3: wp.array(dtype=vec3),
1539
- s4: wp.array(dtype=vec4),
1540
- s5: wp.array(dtype=vec5),
1541
- v2: wp.array(dtype=vec2),
1542
- v3: wp.array(dtype=vec3),
1543
- v4: wp.array(dtype=vec4),
1544
- v5: wp.array(dtype=vec5),
1545
- v20: wp.array(dtype=wptype),
1546
- v21: wp.array(dtype=wptype),
1547
- v30: wp.array(dtype=wptype),
1548
- v31: wp.array(dtype=wptype),
1549
- v32: wp.array(dtype=wptype),
1550
- v40: wp.array(dtype=wptype),
1551
- v41: wp.array(dtype=wptype),
1552
- v42: wp.array(dtype=wptype),
1553
- v43: wp.array(dtype=wptype),
1554
- v50: wp.array(dtype=wptype),
1555
- v51: wp.array(dtype=wptype),
1556
- v52: wp.array(dtype=wptype),
1557
- v53: wp.array(dtype=wptype),
1558
- v54: wp.array(dtype=wptype),
1559
- ):
1560
- v2result = v2[0] + s2[0]
1561
- v3result = v3[0] + s3[0]
1562
- v4result = v4[0] + s4[0]
1563
- v5result = v5[0] + s5[0]
1564
-
1565
- v20[0] = wptype(2) * v2result[0]
1566
- v21[0] = wptype(2) * v2result[1]
1567
-
1568
- v30[0] = wptype(2) * v3result[0]
1569
- v31[0] = wptype(2) * v3result[1]
1570
- v32[0] = wptype(2) * v3result[2]
1571
-
1572
- v40[0] = wptype(2) * v4result[0]
1573
- v41[0] = wptype(2) * v4result[1]
1574
- v42[0] = wptype(2) * v4result[2]
1575
- v43[0] = wptype(2) * v4result[3]
1576
-
1577
- v50[0] = wptype(2) * v5result[0]
1578
- v51[0] = wptype(2) * v5result[1]
1579
- v52[0] = wptype(2) * v5result[2]
1580
- v53[0] = wptype(2) * v5result[3]
1581
- v54[0] = wptype(2) * v5result[4]
1582
-
1583
- kernel = getkernel(check_add, suffix=dtype.__name__)
1584
-
1585
- if register_kernels:
1586
- return
1587
-
1588
- s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1589
- s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1590
- s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1591
- s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1592
- v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1593
- v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1594
- v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1595
- v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1596
- v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1597
- v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1598
- v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1599
- v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1600
- v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1601
- v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1602
- v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1603
- v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1604
- v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1605
- v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1606
- v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1607
- v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1608
- v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1609
- v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1610
- tape = wp.Tape()
1611
- with tape:
1612
- wp.launch(
1613
- kernel,
1614
- dim=1,
1615
- inputs=[
1616
- s2,
1617
- s3,
1618
- s4,
1619
- s5,
1620
- v2,
1621
- v3,
1622
- v4,
1623
- v5,
1624
- ],
1625
- outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
1626
- device=device,
1627
- )
1628
-
1629
- assert_np_equal(v20.numpy()[0], 2 * (v2.numpy()[0, 0] + s2.numpy()[0, 0]), tol=tol)
1630
- assert_np_equal(v21.numpy()[0], 2 * (v2.numpy()[0, 1] + s2.numpy()[0, 1]), tol=tol)
1631
-
1632
- assert_np_equal(v30.numpy()[0], 2 * (v3.numpy()[0, 0] + s3.numpy()[0, 0]), tol=tol)
1633
- assert_np_equal(v31.numpy()[0], 2 * (v3.numpy()[0, 1] + s3.numpy()[0, 1]), tol=tol)
1634
- assert_np_equal(v32.numpy()[0], 2 * (v3.numpy()[0, 2] + s3.numpy()[0, 2]), tol=tol)
1635
-
1636
- assert_np_equal(v40.numpy()[0], 2 * (v4.numpy()[0, 0] + s4.numpy()[0, 0]), tol=tol)
1637
- assert_np_equal(v41.numpy()[0], 2 * (v4.numpy()[0, 1] + s4.numpy()[0, 1]), tol=tol)
1638
- assert_np_equal(v42.numpy()[0], 2 * (v4.numpy()[0, 2] + s4.numpy()[0, 2]), tol=tol)
1639
- assert_np_equal(v43.numpy()[0], 2 * (v4.numpy()[0, 3] + s4.numpy()[0, 3]), tol=tol)
1640
-
1641
- assert_np_equal(v50.numpy()[0], 2 * (v5.numpy()[0, 0] + s5.numpy()[0, 0]), tol=tol)
1642
- assert_np_equal(v51.numpy()[0], 2 * (v5.numpy()[0, 1] + s5.numpy()[0, 1]), tol=tol)
1643
- assert_np_equal(v52.numpy()[0], 2 * (v5.numpy()[0, 2] + s5.numpy()[0, 2]), tol=tol)
1644
- assert_np_equal(v53.numpy()[0], 2 * (v5.numpy()[0, 3] + s5.numpy()[0, 3]), tol=tol)
1645
- assert_np_equal(v54.numpy()[0], 2 * (v5.numpy()[0, 4] + s5.numpy()[0, 4]), tol=2 * tol)
1646
-
1647
- if dtype in np_float_types:
1648
- for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
1649
- tape.backward(loss=l)
1650
- sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [s2, s3, s4, s5]])
1651
- expected_grads = np.zeros_like(sgrads)
1652
-
1653
- expected_grads[i] = 2
1654
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
1655
-
1656
- allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
1657
- assert_np_equal(allgrads, expected_grads, tol=tol)
1658
-
1659
- tape.zero()
1660
-
1661
-
1662
- def test_dotproduct(test, device, dtype, register_kernels=False):
1663
- rng = np.random.default_rng(123)
1664
-
1665
- tol = {
1666
- np.float16: 1.0e-2,
1667
- np.float32: 1.0e-6,
1668
- np.float64: 1.0e-8,
1669
- }.get(dtype, 0)
1670
-
1671
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1672
- vec2 = wp.types.vector(length=2, dtype=wptype)
1673
- vec3 = wp.types.vector(length=3, dtype=wptype)
1674
- vec4 = wp.types.vector(length=4, dtype=wptype)
1675
- vec5 = wp.types.vector(length=5, dtype=wptype)
1676
-
1677
- def check_dot(
1678
- s2: wp.array(dtype=vec2),
1679
- s3: wp.array(dtype=vec3),
1680
- s4: wp.array(dtype=vec4),
1681
- s5: wp.array(dtype=vec5),
1682
- v2: wp.array(dtype=vec2),
1683
- v3: wp.array(dtype=vec3),
1684
- v4: wp.array(dtype=vec4),
1685
- v5: wp.array(dtype=vec5),
1686
- dot2: wp.array(dtype=wptype),
1687
- dot3: wp.array(dtype=wptype),
1688
- dot4: wp.array(dtype=wptype),
1689
- dot5: wp.array(dtype=wptype),
1690
- ):
1691
- dot2[0] = wptype(2) * wp.dot(v2[0], s2[0])
1692
- dot3[0] = wptype(2) * wp.dot(v3[0], s3[0])
1693
- dot4[0] = wptype(2) * wp.dot(v4[0], s4[0])
1694
- dot5[0] = wptype(2) * wp.dot(v5[0], s5[0])
1695
-
1696
- kernel = getkernel(check_dot, suffix=dtype.__name__)
1697
-
1698
- if register_kernels:
1699
- return
1700
-
1701
- s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1702
- s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1703
- s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1704
- s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1705
- v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1706
- v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1707
- v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1708
- v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1709
- dot2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1710
- dot3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1711
- dot4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1712
- dot5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1713
- tape = wp.Tape()
1714
- with tape:
1715
- wp.launch(
1716
- kernel,
1717
- dim=1,
1718
- inputs=[
1719
- s2,
1720
- s3,
1721
- s4,
1722
- s5,
1723
- v2,
1724
- v3,
1725
- v4,
1726
- v5,
1727
- ],
1728
- outputs=[dot2, dot3, dot4, dot5],
1729
- device=device,
1730
- )
1731
-
1732
- assert_np_equal(dot2.numpy()[0], 2.0 * (v2.numpy() * s2.numpy()).sum(), tol=10 * tol)
1733
- assert_np_equal(dot3.numpy()[0], 2.0 * (v3.numpy() * s3.numpy()).sum(), tol=10 * tol)
1734
- assert_np_equal(dot4.numpy()[0], 2.0 * (v4.numpy() * s4.numpy()).sum(), tol=10 * tol)
1735
- assert_np_equal(dot5.numpy()[0], 2.0 * (v5.numpy() * s5.numpy()).sum(), tol=10 * tol)
1736
-
1737
- if dtype in np_float_types:
1738
- tape.backward(loss=dot2)
1739
- sgrads = tape.gradients[s2].numpy()[0]
1740
- expected_grads = 2.0 * v2.numpy()[0]
1741
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
1742
-
1743
- vgrads = tape.gradients[v2].numpy()[0]
1744
- expected_grads = 2.0 * s2.numpy()[0]
1745
- assert_np_equal(vgrads, expected_grads, tol=tol)
1746
-
1747
- tape.zero()
1748
-
1749
- tape.backward(loss=dot3)
1750
- sgrads = tape.gradients[s3].numpy()[0]
1751
- expected_grads = 2.0 * v3.numpy()[0]
1752
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
1753
-
1754
- vgrads = tape.gradients[v3].numpy()[0]
1755
- expected_grads = 2.0 * s3.numpy()[0]
1756
- assert_np_equal(vgrads, expected_grads, tol=tol)
1757
-
1758
- tape.zero()
1759
-
1760
- tape.backward(loss=dot4)
1761
- sgrads = tape.gradients[s4].numpy()[0]
1762
- expected_grads = 2.0 * v4.numpy()[0]
1763
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
1764
-
1765
- vgrads = tape.gradients[v4].numpy()[0]
1766
- expected_grads = 2.0 * s4.numpy()[0]
1767
- assert_np_equal(vgrads, expected_grads, tol=tol)
1768
-
1769
- tape.zero()
1770
-
1771
- tape.backward(loss=dot5)
1772
- sgrads = tape.gradients[s5].numpy()[0]
1773
- expected_grads = 2.0 * v5.numpy()[0]
1774
- assert_np_equal(sgrads, expected_grads, tol=10 * tol)
1775
-
1776
- vgrads = tape.gradients[v5].numpy()[0]
1777
- expected_grads = 2.0 * s5.numpy()[0]
1778
- assert_np_equal(vgrads, expected_grads, tol=10 * tol)
1779
-
1780
- tape.zero()
1781
-
1782
-
1783
- def test_equivalent_types(test, device, dtype, register_kernels=False):
1784
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1785
-
1786
- # vector types
1787
- vec2 = wp.types.vector(length=2, dtype=wptype)
1788
- vec3 = wp.types.vector(length=3, dtype=wptype)
1789
- vec4 = wp.types.vector(length=4, dtype=wptype)
1790
- vec5 = wp.types.vector(length=5, dtype=wptype)
1791
-
1792
- # vector types equivalent to the above
1793
- vec2_equiv = wp.types.vector(length=2, dtype=wptype)
1794
- vec3_equiv = wp.types.vector(length=3, dtype=wptype)
1795
- vec4_equiv = wp.types.vector(length=4, dtype=wptype)
1796
- vec5_equiv = wp.types.vector(length=5, dtype=wptype)
1797
-
1798
- # declare kernel with original types
1799
- def check_equivalence(
1800
- v2: vec2,
1801
- v3: vec3,
1802
- v4: vec4,
1803
- v5: vec5,
1804
- ):
1805
- wp.expect_eq(v2, vec2(wptype(1), wptype(2)))
1806
- wp.expect_eq(v3, vec3(wptype(1), wptype(2), wptype(3)))
1807
- wp.expect_eq(v4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
1808
- wp.expect_eq(v5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))
1809
-
1810
- wp.expect_eq(v2, vec2_equiv(wptype(1), wptype(2)))
1811
- wp.expect_eq(v3, vec3_equiv(wptype(1), wptype(2), wptype(3)))
1812
- wp.expect_eq(v4, vec4_equiv(wptype(1), wptype(2), wptype(3), wptype(4)))
1813
- wp.expect_eq(v5, vec5_equiv(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))
1814
-
1815
- kernel = getkernel(check_equivalence, suffix=dtype.__name__)
1816
-
1817
- if register_kernels:
1818
- return
1819
-
1820
- # call kernel with equivalent types
1821
- v2 = vec2_equiv(1, 2)
1822
- v3 = vec3_equiv(1, 2, 3)
1823
- v4 = vec4_equiv(1, 2, 3, 4)
1824
- v5 = vec5_equiv(1, 2, 3, 4, 5)
1825
-
1826
- wp.launch(kernel, dim=1, inputs=[v2, v3, v4, v5], device=device)
1827
-
1828
-
1829
- def test_conversions(test, device, dtype, register_kernels=False):
1830
- def check_vectors_equal(
1831
- v0: wp.vec3,
1832
- v1: wp.vec3,
1833
- v2: wp.vec3,
1834
- v3: wp.vec3,
1835
- ):
1836
- wp.expect_eq(v1, v0)
1837
- wp.expect_eq(v2, v0)
1838
- wp.expect_eq(v3, v0)
1839
-
1840
- kernel = getkernel(check_vectors_equal, suffix=dtype.__name__)
1841
-
1842
- if register_kernels:
1843
- return
1844
-
1845
- v0 = wp.vec3(1, 2, 3)
1846
-
1847
- # test explicit conversions - constructing vectors from different containers
1848
- v1 = wp.vec3((1, 2, 3))
1849
- v2 = wp.vec3([1, 2, 3])
1850
- v3 = wp.vec3(np.array([1, 2, 3], dtype=dtype))
1851
-
1852
- wp.launch(kernel, dim=1, inputs=[v0, v1, v2, v3], device=device)
1853
-
1854
- # test implicit conversions - passing different containers as vectors to wp.launch()
1855
- v1 = (1, 2, 3)
1856
- v2 = [1, 2, 3]
1857
- v3 = np.array([1, 2, 3], dtype=dtype)
1858
-
1859
- wp.launch(kernel, dim=1, inputs=[v0, v1, v2, v3], device=device)
1860
-
1861
-
1862
- def test_constants(test, device, dtype, register_kernels=False):
1863
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1864
- vec2 = wp.types.vector(length=2, dtype=wptype)
1865
- vec3 = wp.types.vector(length=3, dtype=wptype)
1866
- vec4 = wp.types.vector(length=4, dtype=wptype)
1867
- vec5 = wp.types.vector(length=5, dtype=wptype)
1868
-
1869
- cv2 = wp.constant(vec2(1, 2))
1870
- cv3 = wp.constant(vec3(1, 2, 3))
1871
- cv4 = wp.constant(vec4(1, 2, 3, 4))
1872
- cv5 = wp.constant(vec5(1, 2, 3, 4, 5))
1873
-
1874
- def check_vector_constants():
1875
- wp.expect_eq(cv2, vec2(wptype(1), wptype(2)))
1876
- wp.expect_eq(cv3, vec3(wptype(1), wptype(2), wptype(3)))
1877
- wp.expect_eq(cv4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
1878
- wp.expect_eq(cv5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))
1879
-
1880
- kernel = getkernel(check_vector_constants, suffix=dtype.__name__)
1881
-
1882
- if register_kernels:
1883
- return
1884
-
1885
- wp.launch(kernel, dim=1, inputs=[])
1886
-
1887
-
1888
- def test_minmax(test, device, dtype, register_kernels=False):
1889
- rng = np.random.default_rng(123)
1890
-
1891
- # \TODO: not quite sure why, but the numbers are off for 16 bit float
1892
- # on the cpu (but not cuda). This is probably just the sketchy float16
1893
- # arithmetic I implemented to get all this stuff working, so
1894
- # hopefully that can be fixed when we do that correctly.
1895
- tol = {
1896
- np.float16: 1.0e-2,
1897
- }.get(dtype, 0)
1898
-
1899
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1900
- vec2 = wp.types.vector(length=2, dtype=wptype)
1901
- vec3 = wp.types.vector(length=3, dtype=wptype)
1902
- vec4 = wp.types.vector(length=4, dtype=wptype)
1903
- vec5 = wp.types.vector(length=5, dtype=wptype)
1904
-
1905
- # \TODO: Also not quite sure why: this kernel compiles incredibly
1906
- # slowly though...
1907
- def check_vec_min_max(
1908
- a: wp.array(dtype=wptype, ndim=2),
1909
- b: wp.array(dtype=wptype, ndim=2),
1910
- mins: wp.array(dtype=wptype, ndim=2),
1911
- maxs: wp.array(dtype=wptype, ndim=2),
1912
- ):
1913
- for i in range(10):
1914
- # multiplying by 2 so we've got something to backpropagate:
1915
- a2read = vec2(a[i, 0], a[i, 1])
1916
- b2read = vec2(b[i, 0], b[i, 1])
1917
- c2 = wptype(2) * wp.min(a2read, b2read)
1918
- d2 = wptype(2) * wp.max(a2read, b2read)
1919
-
1920
- a3read = vec3(a[i, 2], a[i, 3], a[i, 4])
1921
- b3read = vec3(b[i, 2], b[i, 3], b[i, 4])
1922
- c3 = wptype(2) * wp.min(a3read, b3read)
1923
- d3 = wptype(2) * wp.max(a3read, b3read)
1924
-
1925
- a4read = vec4(a[i, 5], a[i, 6], a[i, 7], a[i, 8])
1926
- b4read = vec4(b[i, 5], b[i, 6], b[i, 7], b[i, 8])
1927
- c4 = wptype(2) * wp.min(a4read, b4read)
1928
- d4 = wptype(2) * wp.max(a4read, b4read)
1929
-
1930
- a5read = vec5(a[i, 9], a[i, 10], a[i, 11], a[i, 12], a[i, 13])
1931
- b5read = vec5(b[i, 9], b[i, 10], b[i, 11], b[i, 12], b[i, 13])
1932
- c5 = wptype(2) * wp.min(a5read, b5read)
1933
- d5 = wptype(2) * wp.max(a5read, b5read)
1934
-
1935
- mins[i, 0] = c2[0]
1936
- mins[i, 1] = c2[1]
1937
-
1938
- mins[i, 2] = c3[0]
1939
- mins[i, 3] = c3[1]
1940
- mins[i, 4] = c3[2]
1941
-
1942
- mins[i, 5] = c4[0]
1943
- mins[i, 6] = c4[1]
1944
- mins[i, 7] = c4[2]
1945
- mins[i, 8] = c4[3]
1946
-
1947
- mins[i, 9] = c5[0]
1948
- mins[i, 10] = c5[1]
1949
- mins[i, 11] = c5[2]
1950
- mins[i, 12] = c5[3]
1951
- mins[i, 13] = c5[4]
1952
-
1953
- maxs[i, 0] = d2[0]
1954
- maxs[i, 1] = d2[1]
1955
-
1956
- maxs[i, 2] = d3[0]
1957
- maxs[i, 3] = d3[1]
1958
- maxs[i, 4] = d3[2]
1959
-
1960
- maxs[i, 5] = d4[0]
1961
- maxs[i, 6] = d4[1]
1962
- maxs[i, 7] = d4[2]
1963
- maxs[i, 8] = d4[3]
1964
-
1965
- maxs[i, 9] = d5[0]
1966
- maxs[i, 10] = d5[1]
1967
- maxs[i, 11] = d5[2]
1968
- maxs[i, 12] = d5[3]
1969
- maxs[i, 13] = d5[4]
1970
-
1971
- kernel = getkernel(check_vec_min_max, suffix=dtype.__name__)
1972
- output_select_kernel = get_select_kernel2(wptype)
1973
-
1974
- if register_kernels:
1975
- return
1976
-
1977
- a = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)
1978
- b = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)
1979
-
1980
- mins = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)
1981
- maxs = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)
1982
-
1983
- tape = wp.Tape()
1984
- with tape:
1985
- wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
1986
-
1987
- assert_np_equal(mins.numpy(), 2 * np.minimum(a.numpy(), b.numpy()), tol=tol)
1988
- assert_np_equal(maxs.numpy(), 2 * np.maximum(a.numpy(), b.numpy()), tol=tol)
1989
-
1990
- out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1991
- if dtype in np_float_types:
1992
- for i in range(10):
1993
- for j in range(14):
1994
- tape = wp.Tape()
1995
- with tape:
1996
- wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
1997
- wp.launch(output_select_kernel, dim=1, inputs=[mins, i, j], outputs=[out], device=device)
1998
-
1999
- tape.backward(loss=out)
2000
- expected = np.zeros_like(a.numpy())
2001
- expected[i, j] = 2 if (a.numpy()[i, j] < b.numpy()[i, j]) else 0
2002
- assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
2003
- expected[i, j] = 2 if (b.numpy()[i, j] < a.numpy()[i, j]) else 0
2004
- assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
2005
- tape.zero()
2006
-
2007
- tape = wp.Tape()
2008
- with tape:
2009
- wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
2010
- wp.launch(output_select_kernel, dim=1, inputs=[maxs, i, j], outputs=[out], device=device)
2011
-
2012
- tape.backward(loss=out)
2013
- expected = np.zeros_like(a.numpy())
2014
- expected[i, j] = 2 if (a.numpy()[i, j] > b.numpy()[i, j]) else 0
2015
- assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
2016
- expected[i, j] = 2 if (b.numpy()[i, j] > a.numpy()[i, j]) else 0
2017
- assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
2018
- tape.zero()
2019
-
2020
-
2021
- devices = get_test_devices()
2022
-
2023
-
2024
- class TestVecScalarOps(unittest.TestCase):
2025
- pass
2026
-
2027
-
2028
- for dtype in np_scalar_types:
2029
- add_function_test(TestVecScalarOps, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
2030
- add_function_test(TestVecScalarOps, f"test_components_{dtype.__name__}", test_components, devices=None, dtype=dtype)
2031
- add_function_test(
2032
- TestVecScalarOps, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2033
- )
2034
- add_function_test_register_kernel(
2035
- TestVecScalarOps, f"test_constructors_{dtype.__name__}", test_constructors, devices=devices, dtype=dtype
2036
- )
2037
- add_function_test_register_kernel(
2038
- TestVecScalarOps,
2039
- f"test_anon_type_instance_{dtype.__name__}",
2040
- test_anon_type_instance,
2041
- devices=devices,
2042
- dtype=dtype,
2043
- )
2044
- add_function_test_register_kernel(
2045
- TestVecScalarOps, f"test_indexing_{dtype.__name__}", test_indexing, devices=devices, dtype=dtype
2046
- )
2047
- add_function_test_register_kernel(
2048
- TestVecScalarOps, f"test_equality_{dtype.__name__}", test_equality, devices=devices, dtype=dtype
2049
- )
2050
- add_function_test_register_kernel(
2051
- TestVecScalarOps,
2052
- f"test_scalar_multiplication_{dtype.__name__}",
2053
- test_scalar_multiplication,
2054
- devices=devices,
2055
- dtype=dtype,
2056
- )
2057
- add_function_test_register_kernel(
2058
- TestVecScalarOps,
2059
- f"test_scalar_multiplication_rightmul_{dtype.__name__}",
2060
- test_scalar_multiplication_rightmul,
2061
- devices=devices,
2062
- dtype=dtype,
2063
- )
2064
- add_function_test_register_kernel(
2065
- TestVecScalarOps,
2066
- f"test_cw_multiplication_{dtype.__name__}",
2067
- test_cw_multiplication,
2068
- devices=devices,
2069
- dtype=dtype,
2070
- )
2071
- add_function_test_register_kernel(
2072
- TestVecScalarOps, f"test_scalar_division_{dtype.__name__}", test_scalar_division, devices=devices, dtype=dtype
2073
- )
2074
- add_function_test_register_kernel(
2075
- TestVecScalarOps, f"test_cw_division_{dtype.__name__}", test_cw_division, devices=devices, dtype=dtype
2076
- )
2077
- add_function_test_register_kernel(
2078
- TestVecScalarOps, f"test_addition_{dtype.__name__}", test_addition, devices=devices, dtype=dtype
2079
- )
2080
- add_function_test_register_kernel(
2081
- TestVecScalarOps, f"test_dotproduct_{dtype.__name__}", test_dotproduct, devices=devices, dtype=dtype
2082
- )
2083
- add_function_test_register_kernel(
2084
- TestVecScalarOps, f"test_equivalent_types_{dtype.__name__}", test_equivalent_types, devices=devices, dtype=dtype
2085
- )
2086
- add_function_test_register_kernel(
2087
- TestVecScalarOps, f"test_conversions_{dtype.__name__}", test_conversions, devices=devices, dtype=dtype
2088
- )
2089
- add_function_test_register_kernel(
2090
- TestVecScalarOps, f"test_constants_{dtype.__name__}", test_constants, devices=devices, dtype=dtype
2091
- )
2092
-
2093
- # the kernels in this test compile incredibly slowly...
2094
- # add_function_test_register_kernel(TestVecScalarOps, f"test_minmax_{dtype.__name__}", test_minmax, devices=devices, dtype=dtype)
2095
-
2096
-
2097
- if __name__ == "__main__":
2098
- wp.build.clear_kernel_cache()
2099
- unittest.main(verbosity=2, failfast=True)
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
+ wp.init()
16
+
17
+ np_signed_int_types = [
18
+ np.int8,
19
+ np.int16,
20
+ np.int32,
21
+ np.int64,
22
+ np.byte,
23
+ ]
24
+
25
+ np_unsigned_int_types = [
26
+ np.uint8,
27
+ np.uint16,
28
+ np.uint32,
29
+ np.uint64,
30
+ np.ubyte,
31
+ ]
32
+
33
+ np_int_types = np_signed_int_types + np_unsigned_int_types
34
+
35
+ np_float_types = [np.float16, np.float32, np.float64]
36
+
37
+ np_scalar_types = np_int_types + np_float_types
38
+
39
+
40
+ def randvals(rng, shape, dtype):
41
+ if dtype in np_float_types:
42
+ return rng.standard_normal(size=shape).astype(dtype)
43
+ elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
44
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
45
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
46
+
47
+
48
+ kernel_cache = {}
49
+
50
+
51
+ def getkernel(func, suffix=""):
52
+ key = func.__name__ + "_" + suffix
53
+ if key not in kernel_cache:
54
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
55
+ return kernel_cache[key]
56
+
57
+
58
+ def get_select_kernel(dtype):
59
+ def output_select_kernel_fn(
60
+ input: wp.array(dtype=dtype),
61
+ index: int,
62
+ out: wp.array(dtype=dtype),
63
+ ):
64
+ out[0] = input[index]
65
+
66
+ return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
67
+
68
+
69
+ def get_select_kernel2(dtype):
70
+ def output_select_kernel2_fn(
71
+ input: wp.array(dtype=dtype, ndim=2),
72
+ index0: int,
73
+ index1: int,
74
+ out: wp.array(dtype=dtype),
75
+ ):
76
+ out[0] = input[index0, index1]
77
+
78
+ return getkernel(output_select_kernel2_fn, suffix=dtype.__name__)
79
+
80
+
81
+ def test_arrays(test, device, dtype):
82
+ rng = np.random.default_rng(123)
83
+
84
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
85
+ vec2 = wp.types.vector(length=2, dtype=wptype)
86
+ vec3 = wp.types.vector(length=3, dtype=wptype)
87
+ vec4 = wp.types.vector(length=4, dtype=wptype)
88
+ vec5 = wp.types.vector(length=5, dtype=wptype)
89
+
90
+ v2_np = randvals(rng, (10, 2), dtype)
91
+ v3_np = randvals(rng, (10, 3), dtype)
92
+ v4_np = randvals(rng, (10, 4), dtype)
93
+ v5_np = randvals(rng, (10, 5), dtype)
94
+
95
+ v2 = wp.array(v2_np, dtype=vec2, requires_grad=True, device=device)
96
+ v3 = wp.array(v3_np, dtype=vec3, requires_grad=True, device=device)
97
+ v4 = wp.array(v4_np, dtype=vec4, requires_grad=True, device=device)
98
+ v5 = wp.array(v5_np, dtype=vec5, requires_grad=True, device=device)
99
+
100
+ assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
101
+ assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
102
+ assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
103
+ assert_np_equal(v5.numpy(), v5_np, tol=1.0e-6)
104
+
105
+ vec2 = wp.types.vector(length=2, dtype=wptype)
106
+ vec3 = wp.types.vector(length=3, dtype=wptype)
107
+ vec4 = wp.types.vector(length=4, dtype=wptype)
108
+
109
+ v2 = wp.array(v2_np, dtype=vec2, requires_grad=True, device=device)
110
+ v3 = wp.array(v3_np, dtype=vec3, requires_grad=True, device=device)
111
+ v4 = wp.array(v4_np, dtype=vec4, requires_grad=True, device=device)
112
+
113
+ assert_np_equal(v2.numpy(), v2_np, tol=1.0e-6)
114
+ assert_np_equal(v3.numpy(), v3_np, tol=1.0e-6)
115
+ assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
116
+
117
+
118
+ def test_components(test, device, dtype):
119
+ # test accessing vector components from Python - this is especially important
120
+ # for float16, which requires special handling internally
121
+
122
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
123
+ vec3 = wp.types.vector(length=3, dtype=wptype)
124
+
125
+ v = vec3(1, 2, 3)
126
+
127
+ # test __getitem__ for individual components
128
+ test.assertEqual(v[0], 1)
129
+ test.assertEqual(v[1], 2)
130
+ test.assertEqual(v[2], 3)
131
+
132
+ # test __getitem__ for slices
133
+ s = v[:]
134
+ test.assertEqual(s[0], 1)
135
+ test.assertEqual(s[1], 2)
136
+ test.assertEqual(s[2], 3)
137
+
138
+ s = v[1:]
139
+ test.assertEqual(s[0], 2)
140
+ test.assertEqual(s[1], 3)
141
+
142
+ s = v[:2]
143
+ test.assertEqual(s[0], 1)
144
+ test.assertEqual(s[1], 2)
145
+
146
+ s = v[::2]
147
+ test.assertEqual(s[0], 1)
148
+ test.assertEqual(s[1], 3)
149
+
150
+ # test __setitem__ for individual components
151
+ v[0] = 4
152
+ v[1] = 5
153
+ v[2] = 6
154
+ test.assertEqual(v[0], 4)
155
+ test.assertEqual(v[1], 5)
156
+ test.assertEqual(v[2], 6)
157
+
158
+ # test __setitem__ for slices
159
+ v[:] = [7, 8, 9]
160
+ test.assertEqual(v[0], 7)
161
+ test.assertEqual(v[1], 8)
162
+ test.assertEqual(v[2], 9)
163
+
164
+ v[1:] = [10, 11]
165
+ test.assertEqual(v[0], 7)
166
+ test.assertEqual(v[1], 10)
167
+ test.assertEqual(v[2], 11)
168
+
169
+ v[:2] = [12, 13]
170
+ test.assertEqual(v[0], 12)
171
+ test.assertEqual(v[1], 13)
172
+ test.assertEqual(v[2], 11)
173
+
174
+ v[::2] = [14, 15]
175
+ test.assertEqual(v[0], 14)
176
+ test.assertEqual(v[1], 13)
177
+ test.assertEqual(v[2], 15)
178
+
179
+
180
+ def test_py_arithmetic_ops(test, device, dtype):
181
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
182
+
183
+ def make_vec(*args):
184
+ if wptype in wp.types.int_types:
185
+ # Cast to the correct integer type to simulate wrapping.
186
+ return tuple(wptype._type_(x).value for x in args)
187
+
188
+ return args
189
+
190
+ vec_cls = wp.vec(3, wptype)
191
+
192
+ v = vec_cls(1, -2, 3)
193
+ test.assertSequenceEqual(+v, make_vec(1, -2, 3))
194
+ test.assertSequenceEqual(-v, make_vec(-1, 2, -3))
195
+ test.assertSequenceEqual(v + vec_cls(5, 5, 5), make_vec(6, 3, 8))
196
+ test.assertSequenceEqual(v - vec_cls(5, 5, 5), make_vec(-4, -7, -2))
197
+
198
+ v = vec_cls(2, 4, 6)
199
+ test.assertSequenceEqual(v * wptype(2), make_vec(4, 8, 12))
200
+ test.assertSequenceEqual(wptype(2) * v, make_vec(4, 8, 12))
201
+ test.assertSequenceEqual(v / wptype(2), make_vec(1, 2, 3))
202
+ test.assertSequenceEqual(wptype(24) / v, make_vec(12, 6, 4))
203
+
204
+
205
+ def test_constructors(test, device, dtype, register_kernels=False):
206
+ rng = np.random.default_rng(123)
207
+
208
+ tol = {
209
+ np.float16: 5.0e-3,
210
+ np.float32: 1.0e-6,
211
+ np.float64: 1.0e-8,
212
+ }.get(dtype, 0)
213
+
214
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
215
+ vec2 = wp.types.vector(length=2, dtype=wptype)
216
+ vec3 = wp.types.vector(length=3, dtype=wptype)
217
+ vec4 = wp.types.vector(length=4, dtype=wptype)
218
+ vec5 = wp.types.vector(length=5, dtype=wptype)
219
+
220
+ def check_scalar_constructor(
221
+ input: wp.array(dtype=wptype),
222
+ v2: wp.array(dtype=vec2),
223
+ v3: wp.array(dtype=vec3),
224
+ v4: wp.array(dtype=vec4),
225
+ v5: wp.array(dtype=vec5),
226
+ v20: wp.array(dtype=wptype),
227
+ v21: wp.array(dtype=wptype),
228
+ v30: wp.array(dtype=wptype),
229
+ v31: wp.array(dtype=wptype),
230
+ v32: wp.array(dtype=wptype),
231
+ v40: wp.array(dtype=wptype),
232
+ v41: wp.array(dtype=wptype),
233
+ v42: wp.array(dtype=wptype),
234
+ v43: wp.array(dtype=wptype),
235
+ v50: wp.array(dtype=wptype),
236
+ v51: wp.array(dtype=wptype),
237
+ v52: wp.array(dtype=wptype),
238
+ v53: wp.array(dtype=wptype),
239
+ v54: wp.array(dtype=wptype),
240
+ ):
241
+ v2result = vec2(input[0])
242
+ v3result = vec3(input[0])
243
+ v4result = vec4(input[0])
244
+ v5result = vec5(input[0])
245
+
246
+ v2[0] = v2result
247
+ v3[0] = v3result
248
+ v4[0] = v4result
249
+ v5[0] = v5result
250
+
251
+ # multiply outputs by 2 so we've got something to backpropagate
252
+ v20[0] = wptype(2) * v2result[0]
253
+ v21[0] = wptype(2) * v2result[1]
254
+
255
+ v30[0] = wptype(2) * v3result[0]
256
+ v31[0] = wptype(2) * v3result[1]
257
+ v32[0] = wptype(2) * v3result[2]
258
+
259
+ v40[0] = wptype(2) * v4result[0]
260
+ v41[0] = wptype(2) * v4result[1]
261
+ v42[0] = wptype(2) * v4result[2]
262
+ v43[0] = wptype(2) * v4result[3]
263
+
264
+ v50[0] = wptype(2) * v5result[0]
265
+ v51[0] = wptype(2) * v5result[1]
266
+ v52[0] = wptype(2) * v5result[2]
267
+ v53[0] = wptype(2) * v5result[3]
268
+ v54[0] = wptype(2) * v5result[4]
269
+
270
+ def check_vector_constructors(
271
+ input: wp.array(dtype=wptype),
272
+ v2: wp.array(dtype=vec2),
273
+ v3: wp.array(dtype=vec3),
274
+ v4: wp.array(dtype=vec4),
275
+ v5: wp.array(dtype=vec5),
276
+ v20: wp.array(dtype=wptype),
277
+ v21: wp.array(dtype=wptype),
278
+ v30: wp.array(dtype=wptype),
279
+ v31: wp.array(dtype=wptype),
280
+ v32: wp.array(dtype=wptype),
281
+ v40: wp.array(dtype=wptype),
282
+ v41: wp.array(dtype=wptype),
283
+ v42: wp.array(dtype=wptype),
284
+ v43: wp.array(dtype=wptype),
285
+ v50: wp.array(dtype=wptype),
286
+ v51: wp.array(dtype=wptype),
287
+ v52: wp.array(dtype=wptype),
288
+ v53: wp.array(dtype=wptype),
289
+ v54: wp.array(dtype=wptype),
290
+ ):
291
+ v2result = vec2(input[0], input[1])
292
+ v3result = vec3(input[2], input[3], input[4])
293
+ v4result = vec4(input[5], input[6], input[7], input[8])
294
+ v5result = vec5(input[9], input[10], input[11], input[12], input[13])
295
+
296
+ v2[0] = v2result
297
+ v3[0] = v3result
298
+ v4[0] = v4result
299
+ v5[0] = v5result
300
+
301
+ # multiply the output by 2 so we've got something to backpropagate:
302
+ v20[0] = wptype(2) * v2result[0]
303
+ v21[0] = wptype(2) * v2result[1]
304
+
305
+ v30[0] = wptype(2) * v3result[0]
306
+ v31[0] = wptype(2) * v3result[1]
307
+ v32[0] = wptype(2) * v3result[2]
308
+
309
+ v40[0] = wptype(2) * v4result[0]
310
+ v41[0] = wptype(2) * v4result[1]
311
+ v42[0] = wptype(2) * v4result[2]
312
+ v43[0] = wptype(2) * v4result[3]
313
+
314
+ v50[0] = wptype(2) * v5result[0]
315
+ v51[0] = wptype(2) * v5result[1]
316
+ v52[0] = wptype(2) * v5result[2]
317
+ v53[0] = wptype(2) * v5result[3]
318
+ v54[0] = wptype(2) * v5result[4]
319
+
320
+ vec_kernel = getkernel(check_vector_constructors, suffix=dtype.__name__)
321
+ kernel = getkernel(check_scalar_constructor, suffix=dtype.__name__)
322
+
323
+ if register_kernels:
324
+ return
325
+
326
+ input = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
327
+ v2 = wp.zeros(1, dtype=vec2, device=device)
328
+ v3 = wp.zeros(1, dtype=vec3, device=device)
329
+ v4 = wp.zeros(1, dtype=vec4, device=device)
330
+ v5 = wp.zeros(1, dtype=vec5, device=device)
331
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
332
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
333
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
334
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
335
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
336
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
337
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
338
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
339
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
340
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
341
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
342
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
343
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
344
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
345
+
346
+ tape = wp.Tape()
347
+ with tape:
348
+ wp.launch(
349
+ kernel,
350
+ dim=1,
351
+ inputs=[input],
352
+ outputs=[v2, v3, v4, v5, v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
353
+ device=device,
354
+ )
355
+
356
+ if dtype in np_float_types:
357
+ for l in [v20, v21]:
358
+ tape.backward(loss=l)
359
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
360
+ tape.zero()
361
+
362
+ for l in [v30, v31, v32]:
363
+ tape.backward(loss=l)
364
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
365
+ tape.zero()
366
+
367
+ for l in [v40, v41, v42, v43]:
368
+ tape.backward(loss=l)
369
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
370
+ tape.zero()
371
+
372
+ for l in [v50, v51, v52, v53, v54]:
373
+ tape.backward(loss=l)
374
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
375
+ tape.zero()
376
+
377
+ val = input.numpy()[0]
378
+ assert_np_equal(v2.numpy()[0], np.array([val, val]), tol=1.0e-6)
379
+ assert_np_equal(v3.numpy()[0], np.array([val, val, val]), tol=1.0e-6)
380
+ assert_np_equal(v4.numpy()[0], np.array([val, val, val, val]), tol=1.0e-6)
381
+ assert_np_equal(v5.numpy()[0], np.array([val, val, val, val, val]), tol=1.0e-6)
382
+
383
+ assert_np_equal(v20.numpy()[0], 2 * val, tol=1.0e-6)
384
+ assert_np_equal(v21.numpy()[0], 2 * val, tol=1.0e-6)
385
+ assert_np_equal(v30.numpy()[0], 2 * val, tol=1.0e-6)
386
+ assert_np_equal(v31.numpy()[0], 2 * val, tol=1.0e-6)
387
+ assert_np_equal(v32.numpy()[0], 2 * val, tol=1.0e-6)
388
+ assert_np_equal(v40.numpy()[0], 2 * val, tol=1.0e-6)
389
+ assert_np_equal(v41.numpy()[0], 2 * val, tol=1.0e-6)
390
+ assert_np_equal(v42.numpy()[0], 2 * val, tol=1.0e-6)
391
+ assert_np_equal(v43.numpy()[0], 2 * val, tol=1.0e-6)
392
+ assert_np_equal(v50.numpy()[0], 2 * val, tol=1.0e-6)
393
+ assert_np_equal(v51.numpy()[0], 2 * val, tol=1.0e-6)
394
+ assert_np_equal(v52.numpy()[0], 2 * val, tol=1.0e-6)
395
+ assert_np_equal(v53.numpy()[0], 2 * val, tol=1.0e-6)
396
+ assert_np_equal(v54.numpy()[0], 2 * val, tol=1.0e-6)
397
+
398
+ input = wp.array(randvals(rng, [14], dtype), requires_grad=True, device=device)
399
+ tape = wp.Tape()
400
+ with tape:
401
+ wp.launch(
402
+ vec_kernel,
403
+ dim=1,
404
+ inputs=[input],
405
+ outputs=[v2, v3, v4, v5, v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
406
+ device=device,
407
+ )
408
+
409
+ if dtype in np_float_types:
410
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
411
+ tape.backward(loss=l)
412
+ grad = tape.gradients[input].numpy()
413
+ expected_grad = np.zeros_like(grad)
414
+ expected_grad[i] = 2
415
+ assert_np_equal(grad, expected_grad, tol=tol)
416
+ tape.zero()
417
+
418
+ assert_np_equal(v2.numpy()[0, 0], input.numpy()[0], tol=tol)
419
+ assert_np_equal(v2.numpy()[0, 1], input.numpy()[1], tol=tol)
420
+ assert_np_equal(v3.numpy()[0, 0], input.numpy()[2], tol=tol)
421
+ assert_np_equal(v3.numpy()[0, 1], input.numpy()[3], tol=tol)
422
+ assert_np_equal(v3.numpy()[0, 2], input.numpy()[4], tol=tol)
423
+ assert_np_equal(v4.numpy()[0, 0], input.numpy()[5], tol=tol)
424
+ assert_np_equal(v4.numpy()[0, 1], input.numpy()[6], tol=tol)
425
+ assert_np_equal(v4.numpy()[0, 2], input.numpy()[7], tol=tol)
426
+ assert_np_equal(v4.numpy()[0, 3], input.numpy()[8], tol=tol)
427
+ assert_np_equal(v5.numpy()[0, 0], input.numpy()[9], tol=tol)
428
+ assert_np_equal(v5.numpy()[0, 1], input.numpy()[10], tol=tol)
429
+ assert_np_equal(v5.numpy()[0, 2], input.numpy()[11], tol=tol)
430
+ assert_np_equal(v5.numpy()[0, 3], input.numpy()[12], tol=tol)
431
+ assert_np_equal(v5.numpy()[0, 4], input.numpy()[13], tol=tol)
432
+
433
+ assert_np_equal(v20.numpy()[0], 2 * input.numpy()[0], tol=tol)
434
+ assert_np_equal(v21.numpy()[0], 2 * input.numpy()[1], tol=tol)
435
+ assert_np_equal(v30.numpy()[0], 2 * input.numpy()[2], tol=tol)
436
+ assert_np_equal(v31.numpy()[0], 2 * input.numpy()[3], tol=tol)
437
+ assert_np_equal(v32.numpy()[0], 2 * input.numpy()[4], tol=tol)
438
+ assert_np_equal(v40.numpy()[0], 2 * input.numpy()[5], tol=tol)
439
+ assert_np_equal(v41.numpy()[0], 2 * input.numpy()[6], tol=tol)
440
+ assert_np_equal(v42.numpy()[0], 2 * input.numpy()[7], tol=tol)
441
+ assert_np_equal(v43.numpy()[0], 2 * input.numpy()[8], tol=tol)
442
+ assert_np_equal(v50.numpy()[0], 2 * input.numpy()[9], tol=tol)
443
+ assert_np_equal(v51.numpy()[0], 2 * input.numpy()[10], tol=tol)
444
+ assert_np_equal(v52.numpy()[0], 2 * input.numpy()[11], tol=tol)
445
+ assert_np_equal(v53.numpy()[0], 2 * input.numpy()[12], tol=tol)
446
+ assert_np_equal(v54.numpy()[0], 2 * input.numpy()[13], tol=tol)
447
+
448
+
449
+ def test_anon_type_instance(test, device, dtype, register_kernels=False):
450
+ rng = np.random.default_rng(123)
451
+
452
+ tol = {
453
+ np.float16: 5.0e-3,
454
+ np.float32: 1.0e-6,
455
+ np.float64: 1.0e-8,
456
+ }.get(dtype, 0)
457
+
458
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
459
+
460
+ def check_scalar_init(
461
+ input: wp.array(dtype=wptype),
462
+ output: wp.array(dtype=wptype),
463
+ ):
464
+ v2result = wp.vector(input[0], length=2)
465
+ v3result = wp.vector(input[1], length=3)
466
+ v4result = wp.vector(input[2], length=4)
467
+ v5result = wp.vector(input[3], length=5)
468
+
469
+ idx = 0
470
+ for i in range(2):
471
+ output[idx] = wptype(2) * v2result[i]
472
+ idx = idx + 1
473
+ for i in range(3):
474
+ output[idx] = wptype(2) * v3result[i]
475
+ idx = idx + 1
476
+ for i in range(4):
477
+ output[idx] = wptype(2) * v4result[i]
478
+ idx = idx + 1
479
+ for i in range(5):
480
+ output[idx] = wptype(2) * v5result[i]
481
+ idx = idx + 1
482
+
483
+ def check_component_init(
484
+ input: wp.array(dtype=wptype),
485
+ output: wp.array(dtype=wptype),
486
+ ):
487
+ v2result = wp.vector(input[0], input[1])
488
+ v3result = wp.vector(input[2], input[3], input[4])
489
+ v4result = wp.vector(input[5], input[6], input[7], input[8])
490
+ v5result = wp.vector(input[9], input[10], input[11], input[12], input[13])
491
+
492
+ idx = 0
493
+ for i in range(2):
494
+ output[idx] = wptype(2) * v2result[i]
495
+ idx = idx + 1
496
+ for i in range(3):
497
+ output[idx] = wptype(2) * v3result[i]
498
+ idx = idx + 1
499
+ for i in range(4):
500
+ output[idx] = wptype(2) * v4result[i]
501
+ idx = idx + 1
502
+ for i in range(5):
503
+ output[idx] = wptype(2) * v5result[i]
504
+ idx = idx + 1
505
+
506
+ scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
507
+ component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
508
+ output_select_kernel = get_select_kernel(wptype)
509
+
510
+ if register_kernels:
511
+ return
512
+
513
+ input = wp.array(randvals(rng, [4], dtype), requires_grad=True, device=device)
514
+ output = wp.zeros(2 + 3 + 4 + 5, dtype=wptype, requires_grad=True, device=device)
515
+
516
+ wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
517
+
518
+ assert_np_equal(output.numpy()[:2], 2 * np.array([input.numpy()[0]] * 2), tol=1.0e-6)
519
+ assert_np_equal(output.numpy()[2:5], 2 * np.array([input.numpy()[1]] * 3), tol=1.0e-6)
520
+ assert_np_equal(output.numpy()[5:9], 2 * np.array([input.numpy()[2]] * 4), tol=1.0e-6)
521
+ assert_np_equal(output.numpy()[9:], 2 * np.array([input.numpy()[3]] * 5), tol=1.0e-6)
522
+
523
+ if dtype in np_float_types:
524
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
525
+ for i in range(len(output)):
526
+ tape = wp.Tape()
527
+ with tape:
528
+ wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
529
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
530
+
531
+ tape.backward(loss=out)
532
+ expected = np.zeros_like(input.numpy())
533
+ if i < 2:
534
+ expected[0] = 2
535
+ elif i < 5:
536
+ expected[1] = 2
537
+ elif i < 9:
538
+ expected[2] = 2
539
+ else:
540
+ expected[3] = 2
541
+
542
+ assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
543
+
544
+ tape.reset()
545
+ tape.zero()
546
+
547
+ input = wp.array(randvals(rng, [2 + 3 + 4 + 5], dtype), requires_grad=True, device=device)
548
+ output = wp.zeros(2 + 3 + 4 + 5, dtype=wptype, requires_grad=True, device=device)
549
+
550
+ wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
551
+
552
+ assert_np_equal(output.numpy(), 2 * input.numpy(), tol=1.0e-6)
553
+
554
+ if dtype in np_float_types:
555
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
556
+ for i in range(len(output)):
557
+ tape = wp.Tape()
558
+ with tape:
559
+ wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
560
+ wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)
561
+
562
+ tape.backward(loss=out)
563
+ expected = np.zeros_like(input.numpy())
564
+ expected[i] = 2
565
+
566
+ assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)
567
+
568
+ tape.reset()
569
+ tape.zero()
570
+
571
+
572
+ def test_indexing(test, device, dtype, register_kernels=False):
573
+ rng = np.random.default_rng(123)
574
+
575
+ tol = {
576
+ np.float16: 5.0e-3,
577
+ np.float32: 1.0e-6,
578
+ np.float64: 1.0e-8,
579
+ }.get(dtype, 0)
580
+
581
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
582
+ vec2 = wp.types.vector(length=2, dtype=wptype)
583
+ vec3 = wp.types.vector(length=3, dtype=wptype)
584
+ vec4 = wp.types.vector(length=4, dtype=wptype)
585
+ vec5 = wp.types.vector(length=5, dtype=wptype)
586
+
587
+ def check_indexing(
588
+ v2: wp.array(dtype=vec2),
589
+ v3: wp.array(dtype=vec3),
590
+ v4: wp.array(dtype=vec4),
591
+ v5: wp.array(dtype=vec5),
592
+ v20: wp.array(dtype=wptype),
593
+ v21: wp.array(dtype=wptype),
594
+ v30: wp.array(dtype=wptype),
595
+ v31: wp.array(dtype=wptype),
596
+ v32: wp.array(dtype=wptype),
597
+ v40: wp.array(dtype=wptype),
598
+ v41: wp.array(dtype=wptype),
599
+ v42: wp.array(dtype=wptype),
600
+ v43: wp.array(dtype=wptype),
601
+ v50: wp.array(dtype=wptype),
602
+ v51: wp.array(dtype=wptype),
603
+ v52: wp.array(dtype=wptype),
604
+ v53: wp.array(dtype=wptype),
605
+ v54: wp.array(dtype=wptype),
606
+ ):
607
+ # multiply outputs by 2 so we've got something to backpropagate:
608
+ v20[0] = wptype(2) * v2[0][0]
609
+ v21[0] = wptype(2) * v2[0][1]
610
+
611
+ v30[0] = wptype(2) * v3[0][0]
612
+ v31[0] = wptype(2) * v3[0][1]
613
+ v32[0] = wptype(2) * v3[0][2]
614
+
615
+ v40[0] = wptype(2) * v4[0][0]
616
+ v41[0] = wptype(2) * v4[0][1]
617
+ v42[0] = wptype(2) * v4[0][2]
618
+ v43[0] = wptype(2) * v4[0][3]
619
+
620
+ v50[0] = wptype(2) * v5[0][0]
621
+ v51[0] = wptype(2) * v5[0][1]
622
+ v52[0] = wptype(2) * v5[0][2]
623
+ v53[0] = wptype(2) * v5[0][3]
624
+ v54[0] = wptype(2) * v5[0][4]
625
+
626
+ kernel = getkernel(check_indexing, suffix=dtype.__name__)
627
+
628
+ if register_kernels:
629
+ return
630
+
631
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
632
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
633
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
634
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
635
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
636
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
637
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
638
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
639
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
640
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
641
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
642
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
643
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
644
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
645
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
646
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
647
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
648
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
649
+
650
+ tape = wp.Tape()
651
+ with tape:
652
+ wp.launch(
653
+ kernel,
654
+ dim=1,
655
+ inputs=[v2, v3, v4, v5],
656
+ outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
657
+ device=device,
658
+ )
659
+
660
+ if dtype in np_float_types:
661
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
662
+ tape.backward(loss=l)
663
+ allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
664
+ expected_grads = np.zeros_like(allgrads)
665
+ expected_grads[i] = 2
666
+ assert_np_equal(allgrads, expected_grads, tol=tol)
667
+ tape.zero()
668
+
669
+ assert_np_equal(v20.numpy()[0], 2.0 * v2.numpy()[0, 0], tol=tol)
670
+ assert_np_equal(v21.numpy()[0], 2.0 * v2.numpy()[0, 1], tol=tol)
671
+ assert_np_equal(v30.numpy()[0], 2.0 * v3.numpy()[0, 0], tol=tol)
672
+ assert_np_equal(v31.numpy()[0], 2.0 * v3.numpy()[0, 1], tol=tol)
673
+ assert_np_equal(v32.numpy()[0], 2.0 * v3.numpy()[0, 2], tol=tol)
674
+ assert_np_equal(v40.numpy()[0], 2.0 * v4.numpy()[0, 0], tol=tol)
675
+ assert_np_equal(v41.numpy()[0], 2.0 * v4.numpy()[0, 1], tol=tol)
676
+ assert_np_equal(v42.numpy()[0], 2.0 * v4.numpy()[0, 2], tol=tol)
677
+ assert_np_equal(v43.numpy()[0], 2.0 * v4.numpy()[0, 3], tol=tol)
678
+ assert_np_equal(v50.numpy()[0], 2.0 * v5.numpy()[0, 0], tol=tol)
679
+ assert_np_equal(v51.numpy()[0], 2.0 * v5.numpy()[0, 1], tol=tol)
680
+ assert_np_equal(v52.numpy()[0], 2.0 * v5.numpy()[0, 2], tol=tol)
681
+ assert_np_equal(v53.numpy()[0], 2.0 * v5.numpy()[0, 3], tol=tol)
682
+ assert_np_equal(v54.numpy()[0], 2.0 * v5.numpy()[0, 4], tol=tol)
683
+
684
+
685
def test_equality(test, device, dtype, register_kernels=False):
    """Launch a kernel that checks ``==`` / ``!=`` on vectors of length 2 to 5.

    For each length, the first array holds a reference vector and every other
    array holds a copy of it with one component changed. Inside the kernel,
    ``wp.expect_eq`` asserts the reference equals itself and ``wp.expect_neq``
    asserts it differs from each modified copy.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_equality(
        v20: wp.array(dtype=vec2),
        v21: wp.array(dtype=vec2),
        v22: wp.array(dtype=vec2),
        v30: wp.array(dtype=vec3),
        v31: wp.array(dtype=vec3),
        v32: wp.array(dtype=vec3),
        v33: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v41: wp.array(dtype=vec4),
        v42: wp.array(dtype=vec4),
        v43: wp.array(dtype=vec4),
        v44: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
        v51: wp.array(dtype=vec5),
        v52: wp.array(dtype=vec5),
        v53: wp.array(dtype=vec5),
        v54: wp.array(dtype=vec5),
        v55: wp.array(dtype=vec5),
    ):
        wp.expect_eq(v20[0], v20[0])
        wp.expect_neq(v21[0], v20[0])
        wp.expect_neq(v22[0], v20[0])

        wp.expect_eq(v30[0], v30[0])
        wp.expect_neq(v31[0], v30[0])
        wp.expect_neq(v32[0], v30[0])
        wp.expect_neq(v33[0], v30[0])

        wp.expect_eq(v40[0], v40[0])
        wp.expect_neq(v41[0], v40[0])
        wp.expect_neq(v42[0], v40[0])
        wp.expect_neq(v43[0], v40[0])
        wp.expect_neq(v44[0], v40[0])

        wp.expect_eq(v50[0], v50[0])
        wp.expect_neq(v51[0], v50[0])
        wp.expect_neq(v52[0], v50[0])
        wp.expect_neq(v53[0], v50[0])
        wp.expect_neq(v54[0], v50[0])
        wp.expect_neq(v55[0], v50[0])

    kernel = getkernel(check_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    def as_array(components, vtype):
        # Wrap one vector, given as a flat list of components, in a device array.
        return wp.array(components, dtype=vtype, requires_grad=True, device=device)

    def reference_and_negations(length):
        # [1, 2, ..., length] followed by one copy per component with that
        # single component negated.
        base = [float(x) for x in range(1, length + 1)]
        flipped = [[-x if j == k else x for j, x in enumerate(base)] for k in range(length)]
        return [base] + flipped

    vec2_inputs = [as_array(c, vec2) for c in ([1.0, 2.0], [1.0, 3.0], [3.0, 2.0])]
    vec3_inputs = [as_array(c, vec3) for c in reference_and_negations(3)]
    vec4_inputs = [as_array(c, vec4) for c in reference_and_negations(4)]
    vec5_inputs = [as_array(c, vec5) for c in reference_and_negations(5)]

    # Argument order matches the kernel signature: v20..v22, v30..v33, v40..v44, v50..v55.
    wp.launch(
        kernel,
        dim=1,
        inputs=vec2_inputs + vec3_inputs + vec4_inputs + vec5_inputs,
        outputs=[],
        device=device,
    )
786
+
787
+
788
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check ``scalar * vector`` on vectors of length 2 to 5, forward and backward.

    The kernel writes ``2 * (s * v)[k]`` for every component ``k`` of every
    input vector into its own single-element output array. Forward values are
    compared against NumPy. For float dtypes, each output is back-propagated
    in turn, and two gradients are checked:

    * w.r.t. ``s``: ``2 * v[k]``;
    * w.r.t. the vectors: ``2 * s`` at the matching component, zero elsewhere.

    All 14 outputs are checked, including the vec5 components.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_mul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = s[0] * v2[0]
        v3result = s[0] * v3[0]
        v4result = s[0] * v4[0]
        v5result = s[0] * v5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vectors = [v2, v3, v4, v5]

    # One scalar output per vector component, in kernel-signature order:
    # v20..v21, v30..v32, v40..v43, v50..v54.
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]
    output_groups = [outputs[0:2], outputs[2:5], outputs[5:9], outputs[9:14]]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, v2, v3, v4, v5],
            outputs=outputs,
            device=device,
        )

    s_val = s.numpy()[0]
    for vec, outs in zip(vectors, output_groups):
        vec_np = vec.numpy()[0]
        # vec2 results use the base tolerance; longer vectors get 10x headroom.
        group_tol = tol if len(outs) == 2 else 10 * tol
        for k, out in enumerate(outs):
            assert_np_equal(out.numpy()[0], 2 * s_val * vec_np[k], tol=group_tol)

    # Flattened input components; index i lines up with outputs[i].
    incmps = np.concatenate([v.numpy()[0] for v in vectors])

    if dtype in np_float_types:
        # Back-propagate every output, including the vec5 components.
        for i, l in enumerate(outputs):
            tape.backward(loss=l)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vectors])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s_val * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
918
+
919
+
920
def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=False):
    """Check ``vector * scalar`` on vectors of length 2 to 5, forward and backward.

    The kernel writes ``2 * (v * s)[k]`` for every component ``k`` of every
    input vector into its own single-element output array. Forward values are
    compared against NumPy. For float dtypes, each output is back-propagated
    in turn, and two gradients are checked:

    * w.r.t. ``s``: ``2 * v[k]``;
    * w.r.t. the vectors: ``2 * s`` at the matching component, zero elsewhere.

    All 14 outputs are checked, including the vec5 components.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_rightmul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] * s[0]
        v3result = v3[0] * s[0]
        v4result = v4[0] * s[0]
        v5result = v5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_rightmul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vectors = [v2, v3, v4, v5]

    # One scalar output per vector component, in kernel-signature order:
    # v20..v21, v30..v32, v40..v43, v50..v54.
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]
    output_groups = [outputs[0:2], outputs[2:5], outputs[5:9], outputs[9:14]]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, v2, v3, v4, v5],
            outputs=outputs,
            device=device,
        )

    s_val = s.numpy()[0]
    for vec, outs in zip(vectors, output_groups):
        vec_np = vec.numpy()[0]
        # vec2 results use the base tolerance; longer vectors get 10x headroom.
        group_tol = tol if len(outs) == 2 else 10 * tol
        for k, out in enumerate(outs):
            assert_np_equal(out.numpy()[0], 2 * s_val * vec_np[k], tol=group_tol)

    # Flattened input components; index i lines up with outputs[i].
    incmps = np.concatenate([v.numpy()[0] for v in vectors])

    if dtype in np_float_types:
        # Back-propagate every output, including the vec5 components.
        for i, l in enumerate(outputs):
            tape.backward(loss=l)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vectors])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s_val * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
1050
+
1051
+
1052
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_mul`` on vectors of length 2 to 5, forward and backward.

    Each component ``k`` of ``2 * cw_mul(s, v)`` is written to its own output
    array and compared with NumPy. For float dtypes, each output is
    back-propagated in turn. The gradient w.r.t. ``s`` must be ``2 * v[k]``
    and the gradient w.r.t. ``v`` must be ``2 * s[k]``, both only at the
    matching component.
    """
    rng = np.random.default_rng(123)

    tolerances = {np.float16: 5.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_mul(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_mul(s2[0], v2[0])
        v3result = wp.cw_mul(s3[0], v3[0])
        v4result = wp.cw_mul(s4[0], v4[0])
        v5result = wp.cw_mul(s5[0], v5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    shapes = ((2, vec2), (3, vec3), (4, vec4), (5, vec5))

    def random_vector_array(length, vtype):
        return wp.array(randvals(rng, (1, length), dtype), dtype=vtype, requires_grad=True, device=device)

    # RNG draw order: all scale vectors first, then all input vectors.
    scales = [random_vector_array(n, t) for n, t in shapes]
    vectors = [random_vector_array(n, t) for n, t in shapes]

    # Flat outputs in kernel-signature order, plus per-vector groupings of them.
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]
    output_groups = [outputs[0:2], outputs[2:5], outputs[5:9], outputs[9:14]]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=scales + vectors,
            outputs=outputs,
            device=device,
        )

    for scale, vec, outs in zip(scales, vectors, output_groups):
        scale_np = scale.numpy()[0]
        vec_np = vec.numpy()[0]
        for k, out in enumerate(outs):
            assert_np_equal(out.numpy()[0], 2 * scale_np[k] * vec_np[k], tol=10 * tol)

    incmps = np.concatenate([v.numpy()[0] for v in vectors])
    scmps = np.concatenate([v.numpy()[0] for v in scales])

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        sgrads = np.concatenate([tape.gradients[a].numpy()[0] for a in scales])
        expected_sgrads = np.zeros_like(sgrads)
        expected_sgrads[i] = incmps[i] * 2
        assert_np_equal(sgrads, expected_sgrads, tol=10 * tol)

        vgrads = np.concatenate([tape.gradients[a].numpy()[0] for a in vectors])
        expected_vgrads = np.zeros_like(vgrads)
        expected_vgrads[i] = scmps[i] * 2
        assert_np_equal(vgrads, expected_vgrads, tol=10 * tol)

        tape.zero()
1195
+
1196
+
1197
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check ``vector / scalar`` on vectors of length 2 to 5, forward and backward.

    Each component ``k`` of ``2 * (v / s)`` is written to its own output
    array. Integer dtypes are compared against NumPy floor division; all other
    dtypes against true division. For float dtypes, each output is
    back-propagated in turn. The expected gradients are
    ``d/ds = -2 v / s^2`` and ``d/dv = 2 / s``, the latter at the matching
    component only.
    """
    rng = np.random.default_rng(123)

    tolerances = {np.float16: 5.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_div(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] / s[0]
        v3result = v3[0] / s[0]
        v4result = v4[0] / s[0]
        v5result = v5[0] / s[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_div, suffix=dtype.__name__)

    if register_kernels:
        return

    # RNG draw order: the scalar first, then vec2..vec5.
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    vectors = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device)
        for n, vtype in ((2, vec2), (3, vec3), (4, vec4), (5, vec5))
    ]

    # Scalar outputs in kernel-signature order: v20..v21, v30..v32, v40..v43, v50..v54.
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s, *vectors],
            outputs=outputs,
            device=device,
        )

    s_val = s.numpy()[0]
    integer_dtype = dtype in np_int_types
    offset = 0
    for vec in vectors:
        vec_np = vec.numpy()[0]
        length = len(vec_np)
        # vec2 results use the base tolerance; longer vectors get 10x headroom.
        group_tol = tol if length == 2 else 10 * tol
        for k in range(length):
            if integer_dtype:
                expected = 2 * (vec_np[k] // s_val)
            else:
                expected = 2 * vec_np[k] / s_val
            assert_np_equal(outputs[offset + k].numpy()[0], expected, tol=group_tol)
        offset += length

    incmps = np.concatenate([v.numpy()[0] for v in vectors])

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)
        sgrad = tape.gradients[s].numpy()[0]

        # d/ds v/s = -v/s^2
        assert_np_equal(sgrad, -2 * incmps[i] / (s_val * s_val), tol=10 * tol)

        vgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vectors])
        expected_vgrads = np.zeros_like(vgrads)
        # d/dv v/s = 1/s
        expected_vgrads[i] = 2 / s_val
        assert_np_equal(vgrads, expected_vgrads, tol=tol)
        tape.zero()
1351
+
1352
+
1353
+ def test_cw_division(test, device, dtype, register_kernels=False):
1354
+ rng = np.random.default_rng(123)
1355
+
1356
+ tol = {
1357
+ np.float16: 1.0e-2,
1358
+ np.float32: 1.0e-6,
1359
+ np.float64: 1.0e-8,
1360
+ }.get(dtype, 0)
1361
+
1362
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1363
+ vec2 = wp.types.vector(length=2, dtype=wptype)
1364
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1365
+ vec4 = wp.types.vector(length=4, dtype=wptype)
1366
+ vec5 = wp.types.vector(length=5, dtype=wptype)
1367
+
1368
+ def check_cw_div(
1369
+ s2: wp.array(dtype=vec2),
1370
+ s3: wp.array(dtype=vec3),
1371
+ s4: wp.array(dtype=vec4),
1372
+ s5: wp.array(dtype=vec5),
1373
+ v2: wp.array(dtype=vec2),
1374
+ v3: wp.array(dtype=vec3),
1375
+ v4: wp.array(dtype=vec4),
1376
+ v5: wp.array(dtype=vec5),
1377
+ v20: wp.array(dtype=wptype),
1378
+ v21: wp.array(dtype=wptype),
1379
+ v30: wp.array(dtype=wptype),
1380
+ v31: wp.array(dtype=wptype),
1381
+ v32: wp.array(dtype=wptype),
1382
+ v40: wp.array(dtype=wptype),
1383
+ v41: wp.array(dtype=wptype),
1384
+ v42: wp.array(dtype=wptype),
1385
+ v43: wp.array(dtype=wptype),
1386
+ v50: wp.array(dtype=wptype),
1387
+ v51: wp.array(dtype=wptype),
1388
+ v52: wp.array(dtype=wptype),
1389
+ v53: wp.array(dtype=wptype),
1390
+ v54: wp.array(dtype=wptype),
1391
+ ):
1392
+ v2result = wp.cw_div(v2[0], s2[0])
1393
+ v3result = wp.cw_div(v3[0], s3[0])
1394
+ v4result = wp.cw_div(v4[0], s4[0])
1395
+ v5result = wp.cw_div(v5[0], s5[0])
1396
+
1397
+ v20[0] = wptype(2) * v2result[0]
1398
+ v21[0] = wptype(2) * v2result[1]
1399
+
1400
+ v30[0] = wptype(2) * v3result[0]
1401
+ v31[0] = wptype(2) * v3result[1]
1402
+ v32[0] = wptype(2) * v3result[2]
1403
+
1404
+ v40[0] = wptype(2) * v4result[0]
1405
+ v41[0] = wptype(2) * v4result[1]
1406
+ v42[0] = wptype(2) * v4result[2]
1407
+ v43[0] = wptype(2) * v4result[3]
1408
+
1409
+ v50[0] = wptype(2) * v5result[0]
1410
+ v51[0] = wptype(2) * v5result[1]
1411
+ v52[0] = wptype(2) * v5result[2]
1412
+ v53[0] = wptype(2) * v5result[3]
1413
+ v54[0] = wptype(2) * v5result[4]
1414
+
1415
+ kernel = getkernel(check_cw_div, suffix=dtype.__name__)
1416
+
1417
+ if register_kernels:
1418
+ return
1419
+
1420
+ s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1421
+ s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1422
+ s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1423
+ s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1424
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1425
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1426
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1427
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1428
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1429
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1430
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1431
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1432
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1433
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1434
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1435
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1436
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1437
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1438
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1439
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1440
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1441
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1442
+ tape = wp.Tape()
1443
+ with tape:
1444
+ wp.launch(
1445
+ kernel,
1446
+ dim=1,
1447
+ inputs=[
1448
+ s2,
1449
+ s3,
1450
+ s4,
1451
+ s5,
1452
+ v2,
1453
+ v3,
1454
+ v4,
1455
+ v5,
1456
+ ],
1457
+ outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
1458
+ device=device,
1459
+ )
1460
+
1461
+ if dtype in np_int_types:
1462
+ assert_np_equal(v20.numpy()[0], 2 * (v2.numpy()[0, 0] // s2.numpy()[0, 0]), tol=tol)
1463
+ assert_np_equal(v21.numpy()[0], 2 * (v2.numpy()[0, 1] // s2.numpy()[0, 1]), tol=tol)
1464
+
1465
+ assert_np_equal(v30.numpy()[0], 2 * (v3.numpy()[0, 0] // s3.numpy()[0, 0]), tol=tol)
1466
+ assert_np_equal(v31.numpy()[0], 2 * (v3.numpy()[0, 1] // s3.numpy()[0, 1]), tol=tol)
1467
+ assert_np_equal(v32.numpy()[0], 2 * (v3.numpy()[0, 2] // s3.numpy()[0, 2]), tol=tol)
1468
+
1469
+ assert_np_equal(v40.numpy()[0], 2 * (v4.numpy()[0, 0] // s4.numpy()[0, 0]), tol=tol)
1470
+ assert_np_equal(v41.numpy()[0], 2 * (v4.numpy()[0, 1] // s4.numpy()[0, 1]), tol=tol)
1471
+ assert_np_equal(v42.numpy()[0], 2 * (v4.numpy()[0, 2] // s4.numpy()[0, 2]), tol=tol)
1472
+ assert_np_equal(v43.numpy()[0], 2 * (v4.numpy()[0, 3] // s4.numpy()[0, 3]), tol=tol)
1473
+
1474
+ assert_np_equal(v50.numpy()[0], 2 * (v5.numpy()[0, 0] // s5.numpy()[0, 0]), tol=tol)
1475
+ assert_np_equal(v51.numpy()[0], 2 * (v5.numpy()[0, 1] // s5.numpy()[0, 1]), tol=tol)
1476
+ assert_np_equal(v52.numpy()[0], 2 * (v5.numpy()[0, 2] // s5.numpy()[0, 2]), tol=tol)
1477
+ assert_np_equal(v53.numpy()[0], 2 * (v5.numpy()[0, 3] // s5.numpy()[0, 3]), tol=tol)
1478
+ assert_np_equal(v54.numpy()[0], 2 * (v5.numpy()[0, 4] // s5.numpy()[0, 4]), tol=tol)
1479
+ else:
1480
+ assert_np_equal(v20.numpy()[0], 2 * v2.numpy()[0, 0] / s2.numpy()[0, 0], tol=tol)
1481
+ assert_np_equal(v21.numpy()[0], 2 * v2.numpy()[0, 1] / s2.numpy()[0, 1], tol=tol)
1482
+
1483
+ assert_np_equal(v30.numpy()[0], 2 * v3.numpy()[0, 0] / s3.numpy()[0, 0], tol=tol)
1484
+ assert_np_equal(v31.numpy()[0], 2 * v3.numpy()[0, 1] / s3.numpy()[0, 1], tol=tol)
1485
+ assert_np_equal(v32.numpy()[0], 2 * v3.numpy()[0, 2] / s3.numpy()[0, 2], tol=tol)
1486
+
1487
+ assert_np_equal(v40.numpy()[0], 2 * v4.numpy()[0, 0] / s4.numpy()[0, 0], tol=tol)
1488
+ assert_np_equal(v41.numpy()[0], 2 * v4.numpy()[0, 1] / s4.numpy()[0, 1], tol=tol)
1489
+ assert_np_equal(v42.numpy()[0], 2 * v4.numpy()[0, 2] / s4.numpy()[0, 2], tol=tol)
1490
+ assert_np_equal(v43.numpy()[0], 2 * v4.numpy()[0, 3] / s4.numpy()[0, 3], tol=tol)
1491
+
1492
+ assert_np_equal(v50.numpy()[0], 2 * v5.numpy()[0, 0] / s5.numpy()[0, 0], tol=tol)
1493
+ assert_np_equal(v51.numpy()[0], 2 * v5.numpy()[0, 1] / s5.numpy()[0, 1], tol=tol)
1494
+ assert_np_equal(v52.numpy()[0], 2 * v5.numpy()[0, 2] / s5.numpy()[0, 2], tol=tol)
1495
+ assert_np_equal(v53.numpy()[0], 2 * v5.numpy()[0, 3] / s5.numpy()[0, 3], tol=tol)
1496
+ assert_np_equal(v54.numpy()[0], 2 * v5.numpy()[0, 4] / s5.numpy()[0, 4], tol=tol)
1497
+
1498
+ if dtype in np_float_types:
1499
+ incmps = np.concatenate([v.numpy()[0] for v in [v2, v3, v4, v5]])
1500
+ scmps = np.concatenate([v.numpy()[0] for v in [s2, s3, s4, s5]])
1501
+
1502
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
1503
+ tape.backward(loss=l)
1504
+ sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [s2, s3, s4, s5]])
1505
+ expected_grads = np.zeros_like(sgrads)
1506
+
1507
+ # d/ds v/s = -v/s^2
1508
+ expected_grads[i] = -incmps[i] * 2 / (scmps[i] * scmps[i])
1509
+ assert_np_equal(sgrads, expected_grads, tol=20 * tol)
1510
+
1511
+ allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
1512
+ expected_grads = np.zeros_like(allgrads)
1513
+
1514
+ # d/dv v/s = 1/s
1515
+ expected_grads[i] = 2 / scmps[i]
1516
+ assert_np_equal(allgrads, expected_grads, tol=tol)
1517
+
1518
+ tape.zero()
1519
+
1520
+
1521
def test_addition(test, device, dtype, register_kernels=False):
    """Check component-wise ``v + s`` on vectors of length 2-5, forward and adjoint.

    The kernel writes ``2 * (v + s)[k]`` for every component ``k`` into its own
    one-element output array, so the gradient of each component can be checked
    in isolation (float dtypes only).
    """
    rng = np.random.default_rng(123)

    # Per-dtype comparison tolerance; integer dtypes must match exactly.
    tol = {np.float16: 5.0e-3, np.float32: 1.0e-6, np.float64: 1.0e-8}.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec_types = {n: wp.types.vector(length=n, dtype=wptype) for n in range(2, 6)}
    vec2, vec3, vec4, vec5 = vec_types.values()

    def check_add(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] + s2[0]
        v3result = v3[0] + s3[0]
        v4result = v4[0] + s4[0]
        v5result = v5[0] + s5[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_add, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vectors():
        # One single-element array per vector length, drawn in length order 2..5.
        return {
            n: wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in vec_types.items()
        }

    # All s-vectors are drawn before all v-vectors; this fixes the RNG stream.
    svecs = random_vectors()
    vvecs = random_vectors()

    # (length, component) for each scalar output, in kernel-argument order.
    components = [(n, k) for n in vec_types for k in range(n)]
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in components]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[*svecs.values(), *vvecs.values()],
            outputs=outputs,
            device=device,
        )

    for (n, k), out in zip(components, outputs):
        expected = 2 * (vvecs[n].numpy()[0, k] + svecs[n].numpy()[0, k])
        # The last component of the length-5 vector is compared with a doubled tolerance.
        comp_tol = 2 * tol if (n, k) == (5, 4) else tol
        assert_np_equal(out.numpy()[0], expected, tol=comp_tol)

    if dtype in np_float_types:
        for idx, loss in enumerate(outputs):
            tape.backward(loss=loss)

            # d/ds 2*(v + s) = 2, and only the selected component is non-zero.
            sgrads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in svecs.values()])
            expected_grads = np.zeros_like(sgrads)
            expected_grads[idx] = 2
            assert_np_equal(sgrads, expected_grads, tol=10 * tol)

            # d/dv 2*(v + s) = 2 as well, so the same expectation applies.
            vgrads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in vvecs.values()])
            assert_np_equal(vgrads, expected_grads, tol=tol)

            tape.zero()
1660
+
1661
+
1662
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Check ``2 * wp.dot(v, s)`` on vectors of length 2-5, forward and adjoint.

    Gradients are only checked for float dtypes: d/ds = 2*v and d/dv = 2*s.
    """
    rng = np.random.default_rng(123)

    # Per-dtype comparison tolerance; integer dtypes must match exactly.
    tol = {np.float16: 1.0e-2, np.float32: 1.0e-6, np.float64: 1.0e-8}.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec_types = {n: wp.types.vector(length=n, dtype=wptype) for n in range(2, 6)}
    vec2, vec3, vec4, vec5 = vec_types.values()

    def check_dot(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        dot2[0] = wptype(2) * wp.dot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.dot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.dot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.dot(v5[0], s5[0])

    kernel = getkernel(check_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vectors():
        # One single-element array per vector length, drawn in length order 2..5.
        return {
            n: wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in vec_types.items()
        }

    # All s-vectors are drawn before all v-vectors; this fixes the RNG stream.
    svecs = random_vectors()
    vvecs = random_vectors()
    dots = {n: wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for n in vec_types}

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[*svecs.values(), *vvecs.values()],
            outputs=list(dots.values()),
            device=device,
        )

    for n, dot in dots.items():
        expected = 2.0 * (vvecs[n].numpy() * svecs[n].numpy()).sum()
        assert_np_equal(dot.numpy()[0], expected, tol=10 * tol)

    if dtype in np_float_types:
        for n, dot in dots.items():
            svec = svecs[n]
            vvec = vvecs[n]
            tape.backward(loss=dot)

            # d/ds 2*dot(v, s) = 2*v
            assert_np_equal(tape.gradients[svec].numpy()[0], 2.0 * vvec.numpy()[0], tol=10 * tol)

            # d/dv 2*dot(v, s) = 2*s; the length-5 case uses a looser bound.
            vgrad_tol = 10 * tol if n == 5 else tol
            assert_np_equal(tape.gradients[vvec].numpy()[0], 2.0 * svec.numpy()[0], tol=vgrad_tol)

            tape.zero()
1781
+
1782
+
1783
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Check that a kernel declared with one set of vector types accepts another.

    Two sets of vector types with identical length and scalar type are built
    independently. The kernel's parameters use the first set. It is launched
    with instances of the second set and compares its arguments against
    values built from both sets.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Types used in the kernel signature.
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in range(2, 6))

    # Separately constructed types with the same length and dtype as above.
    vec2_equiv, vec3_equiv, vec4_equiv, vec5_equiv = (wp.types.vector(length=n, dtype=wptype) for n in range(2, 6))

    def check_equivalence(
        v2: vec2,
        v3: vec3,
        v4: vec4,
        v5: vec5,
    ):
        wp.expect_eq(v2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

        wp.expect_eq(v2, vec2_equiv(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3_equiv(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4_equiv(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5_equiv(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # Arguments are instances of the *equivalent* types, not the declared ones.
    equiv_args = [
        vec2_equiv(1, 2),
        vec3_equiv(1, 2, 3),
        vec4_equiv(1, 2, 3, 4),
        vec5_equiv(1, 2, 3, 4, 5),
    ]
    wp.launch(kernel, dim=1, inputs=equiv_args, device=device)
1827
+
1828
+
1829
def test_conversions(test, device, dtype, register_kernels=False):
    """Check that containers convert to ``wp.vec3`` both explicitly and implicitly.

    Tuples, lists and NumPy arrays (of ``dtype``) are converted in two ways:
    first through the ``wp.vec3`` constructor, then by passing them directly
    to ``wp.launch``. Each result must equal ``wp.vec3(1, 2, 3)`` inside the
    kernel.
    """

    def check_vectors_equal(
        v0: wp.vec3,
        v1: wp.vec3,
        v2: wp.vec3,
        v3: wp.vec3,
    ):
        wp.expect_eq(v1, v0)
        wp.expect_eq(v2, v0)
        wp.expect_eq(v3, v0)

    kernel = getkernel(check_vectors_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.vec3(1, 2, 3)

    def launch_with(*candidates):
        # The reference is always the first kernel argument.
        wp.launch(kernel, dim=1, inputs=[reference, *candidates], device=device)

    # Explicit: build wp.vec3 from different containers.
    launch_with(wp.vec3((1, 2, 3)), wp.vec3([1, 2, 3]), wp.vec3(np.array([1, 2, 3], dtype=dtype)))

    # Implicit: hand the raw containers to wp.launch() and let it convert them.
    launch_with((1, 2, 3), [1, 2, 3], np.array([1, 2, 3], dtype=dtype))
1860
+
1861
+
1862
def test_constants(test, device, dtype, register_kernels=False):
    """Check that ``wp.constant`` vectors of length 2-5 hold the expected values in a kernel."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in range(2, 6))

    # Constants referenced by name from inside the kernel below.
    cv2 = wp.constant(vec2(1, 2))
    cv3 = wp.constant(vec3(1, 2, 3))
    cv4 = wp.constant(vec4(1, 2, 3, 4))
    cv5 = wp.constant(vec5(1, 2, 3, 4, 5))

    def check_vector_constants():
        wp.expect_eq(cv2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(cv3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(cv4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(cv5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_vector_constants, suffix=dtype.__name__)

    # During the registration pass only the kernel is built; otherwise run it.
    if not register_kernels:
        wp.launch(kernel, dim=1, inputs=[], device=device)
1886
+
1887
+
1888
def test_minmax(test, device, dtype, register_kernels=False):
    """Check component-wise ``wp.min`` / ``wp.max`` on vectors of length 2-5.

    Each of the 10 rows of ``a`` and ``b`` is split into a vec2 (columns 0-1),
    vec3 (2-4), vec4 (5-8) and vec5 (9-13). The kernel writes ``2 * min`` and
    ``2 * max`` back into the matching columns of ``mins`` and ``maxs``.
    For float dtypes, the gradient of every single output element is checked
    against both inputs.
    """
    rng = np.random.default_rng(123)

    # TODO: float16 results are slightly off on the CPU (but not on CUDA); this is
    # probably due to the provisional float16 arithmetic and may be tightened
    # once that is implemented properly.
    tol = {np.float16: 1.0e-2}.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in range(2, 6))

    # TODO: this kernel compiles very slowly; the cause is not yet understood.
    def check_vec_min_max(
        a: wp.array(dtype=wptype, ndim=2),
        b: wp.array(dtype=wptype, ndim=2),
        mins: wp.array(dtype=wptype, ndim=2),
        maxs: wp.array(dtype=wptype, ndim=2),
    ):
        for i in range(10):
            # multiplying by 2 so we've got something to backpropagate:
            a2read = vec2(a[i, 0], a[i, 1])
            b2read = vec2(b[i, 0], b[i, 1])
            c2 = wptype(2) * wp.min(a2read, b2read)
            d2 = wptype(2) * wp.max(a2read, b2read)

            a3read = vec3(a[i, 2], a[i, 3], a[i, 4])
            b3read = vec3(b[i, 2], b[i, 3], b[i, 4])
            c3 = wptype(2) * wp.min(a3read, b3read)
            d3 = wptype(2) * wp.max(a3read, b3read)

            a4read = vec4(a[i, 5], a[i, 6], a[i, 7], a[i, 8])
            b4read = vec4(b[i, 5], b[i, 6], b[i, 7], b[i, 8])
            c4 = wptype(2) * wp.min(a4read, b4read)
            d4 = wptype(2) * wp.max(a4read, b4read)

            a5read = vec5(a[i, 9], a[i, 10], a[i, 11], a[i, 12], a[i, 13])
            b5read = vec5(b[i, 9], b[i, 10], b[i, 11], b[i, 12], b[i, 13])
            c5 = wptype(2) * wp.min(a5read, b5read)
            d5 = wptype(2) * wp.max(a5read, b5read)

            mins[i, 0] = c2[0]
            mins[i, 1] = c2[1]

            mins[i, 2] = c3[0]
            mins[i, 3] = c3[1]
            mins[i, 4] = c3[2]

            mins[i, 5] = c4[0]
            mins[i, 6] = c4[1]
            mins[i, 7] = c4[2]
            mins[i, 8] = c4[3]

            mins[i, 9] = c5[0]
            mins[i, 10] = c5[1]
            mins[i, 11] = c5[2]
            mins[i, 12] = c5[3]
            mins[i, 13] = c5[4]

            maxs[i, 0] = d2[0]
            maxs[i, 1] = d2[1]

            maxs[i, 2] = d3[0]
            maxs[i, 3] = d3[1]
            maxs[i, 4] = d3[2]

            maxs[i, 5] = d4[0]
            maxs[i, 6] = d4[1]
            maxs[i, 7] = d4[2]
            maxs[i, 8] = d4[3]

            maxs[i, 9] = d5[0]
            maxs[i, 10] = d5[1]
            maxs[i, 11] = d5[2]
            maxs[i, 12] = d5[3]
            maxs[i, 13] = d5[4]

    kernel = getkernel(check_vec_min_max, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel2(wptype)

    if register_kernels:
        return

    rows, cols = 10, 14
    a = wp.array(randvals(rng, (rows, cols), dtype), dtype=wptype, requires_grad=True, device=device)
    b = wp.array(randvals(rng, (rows, cols), dtype), dtype=wptype, requires_grad=True, device=device)

    mins = wp.zeros((rows, cols), dtype=wptype, requires_grad=True, device=device)
    maxs = wp.zeros((rows, cols), dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)

    assert_np_equal(mins.numpy(), 2 * np.minimum(a.numpy(), b.numpy()), tol=tol)
    assert_np_equal(maxs.numpy(), 2 * np.maximum(a.numpy(), b.numpy()), tol=tol)

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    if dtype not in np_float_types:
        return

    # The kernel only reads a and b, so their host copies can be taken once.
    a_host = a.numpy()
    b_host = b.numpy()

    # (selected output, predicate that is true when its first argument is the one picked)
    selections = ((mins, np.less), (maxs, np.greater))

    for i in range(rows):
        for j in range(cols):
            for selected, picks_first in selections:
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
                    wp.launch(output_select_kernel, dim=1, inputs=[selected, i, j], outputs=[out], device=device)

                tape.backward(loss=out)

                # Only the input that was actually picked receives the factor-2 gradient.
                expected = np.zeros_like(a_host)
                expected[i, j] = 2 if picks_first(a_host[i, j], b_host[i, j]) else 0
                assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
                expected[i, j] = 2 if picks_first(b_host[i, j], a_host[i, j]) else 0
                assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
                tape.zero()
2019
+
2020
+
2021
# Device list passed as `devices=` to every device-dependent test registration below.
devices = get_test_devices()
2022
+
2023
+
2024
class TestVecScalarOps(unittest.TestCase):
    """Test case populated programmatically via add_function_test / add_function_test_register_kernel."""

    pass
2026
+
2027
+
2028
# (test-name prefix, test function) pairs registered through
# add_function_test_register_kernel for every scalar dtype, in registration order.
_KERNEL_TESTS = (
    ("test_constructors", test_constructors),
    ("test_anon_type_instance", test_anon_type_instance),
    ("test_indexing", test_indexing),
    ("test_equality", test_equality),
    ("test_scalar_multiplication", test_scalar_multiplication),
    ("test_scalar_multiplication_rightmul", test_scalar_multiplication_rightmul),
    ("test_cw_multiplication", test_cw_multiplication),
    ("test_scalar_division", test_scalar_division),
    ("test_cw_division", test_cw_division),
    ("test_addition", test_addition),
    ("test_dotproduct", test_dotproduct),
    ("test_equivalent_types", test_equivalent_types),
    ("test_conversions", test_conversions),
    ("test_constants", test_constants),
)

for dtype in np_scalar_types:
    suffix = dtype.__name__

    add_function_test(TestVecScalarOps, f"test_arrays_{suffix}", test_arrays, devices=devices, dtype=dtype)
    # These two are registered with devices=None.
    add_function_test(TestVecScalarOps, f"test_components_{suffix}", test_components, devices=None, dtype=dtype)
    add_function_test(
        TestVecScalarOps, f"test_py_arithmetic_ops_{suffix}", test_py_arithmetic_ops, devices=None, dtype=dtype
    )

    # Loop variable deliberately not prefixed with "test" so pytest does not collect it.
    for prefix, case_fn in _KERNEL_TESTS:
        add_function_test_register_kernel(
            TestVecScalarOps, f"{prefix}_{suffix}", case_fn, devices=devices, dtype=dtype
        )

    # test_minmax is deliberately not registered: its kernels compile very slowly.
    # add_function_test_register_kernel(TestVecScalarOps, f"test_minmax_{dtype.__name__}", test_minmax, devices=devices, dtype=dtype)
2095
+
2096
+
2097
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably so kernels are rebuilt for this run).
    wp.build.clear_kernel_cache()
    # Verbose per-test output; failfast stops the run at the first failure or error.
    unittest.main(verbosity=2, failfast=True)