warp-lang 1.0.2__py3-none-win_amd64.whl → 1.2.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic (the original page linked to further details).

Files changed (356)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +88 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3693 -3354
  8. warp/codegen.py +2925 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +49 -45
  11. warp/context.py +5409 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +381 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -277
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
  34. warp/examples/benchmarks/benchmark_launches.py +293 -295
  35. warp/examples/browse.py +29 -29
  36. warp/examples/core/example_dem.py +232 -219
  37. warp/examples/core/example_fluid.py +291 -267
  38. warp/examples/core/example_graph_capture.py +142 -126
  39. warp/examples/core/example_marching_cubes.py +186 -174
  40. warp/examples/core/example_mesh.py +172 -155
  41. warp/examples/core/example_mesh_intersect.py +203 -193
  42. warp/examples/core/example_nvdb.py +174 -170
  43. warp/examples/core/example_raycast.py +103 -90
  44. warp/examples/core/example_raymarch.py +197 -178
  45. warp/examples/core/example_render_opengl.py +183 -141
  46. warp/examples/core/example_sph.py +403 -387
  47. warp/examples/core/example_torch.py +219 -181
  48. warp/examples/core/example_wave.py +261 -248
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +432 -389
  51. warp/examples/fem/example_burgers.py +262 -0
  52. warp/examples/fem/example_convection_diffusion.py +180 -168
  53. warp/examples/fem/example_convection_diffusion_dg.py +217 -209
  54. warp/examples/fem/example_deformed_geometry.py +175 -159
  55. warp/examples/fem/example_diffusion.py +199 -173
  56. warp/examples/fem/example_diffusion_3d.py +178 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +219 -214
  58. warp/examples/fem/example_mixed_elasticity.py +242 -222
  59. warp/examples/fem/example_navier_stokes.py +257 -243
  60. warp/examples/fem/example_stokes.py +218 -192
  61. warp/examples/fem/example_stokes_transfer.py +263 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +258 -246
  65. warp/examples/optim/example_cloth_throw.py +220 -209
  66. warp/examples/optim/example_diffray.py +564 -536
  67. warp/examples/optim/example_drone.py +862 -835
  68. warp/examples/optim/example_inverse_kinematics.py +174 -168
  69. warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
  70. warp/examples/optim/example_spring_cage.py +237 -231
  71. warp/examples/optim/example_trajectory.py +221 -199
  72. warp/examples/optim/example_walker.py +304 -293
  73. warp/examples/sim/example_cartpole.py +137 -129
  74. warp/examples/sim/example_cloth.py +194 -186
  75. warp/examples/sim/example_granular.py +122 -111
  76. warp/examples/sim/example_granular_collision_sdf.py +195 -186
  77. warp/examples/sim/example_jacobian_ik.py +234 -214
  78. warp/examples/sim/example_particle_chain.py +116 -105
  79. warp/examples/sim/example_quadruped.py +191 -180
  80. warp/examples/sim/example_rigid_chain.py +195 -187
  81. warp/examples/sim/example_rigid_contact.py +187 -177
  82. warp/examples/sim/example_rigid_force.py +125 -125
  83. warp/examples/sim/example_rigid_gyroscopic.py +107 -95
  84. warp/examples/sim/example_rigid_soft_contact.py +132 -122
  85. warp/examples/sim/example_soft_body.py +188 -177
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +61 -27
  88. warp/fem/cache.py +403 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +16 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +748 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +437 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/nanogrid.py +455 -0
  106. warp/fem/geometry/partition.py +374 -376
  107. warp/fem/geometry/quadmesh_2d.py +532 -532
  108. warp/fem/geometry/tetmesh.py +840 -840
  109. warp/fem/geometry/trimesh_2d.py +577 -577
  110. warp/fem/integrate.py +1684 -1615
  111. warp/fem/operator.py +190 -191
  112. warp/fem/polynomial.py +214 -213
  113. warp/fem/quadrature/__init__.py +2 -2
  114. warp/fem/quadrature/pic_quadrature.py +243 -245
  115. warp/fem/quadrature/quadrature.py +295 -294
  116. warp/fem/space/__init__.py +179 -292
  117. warp/fem/space/basis_space.py +522 -489
  118. warp/fem/space/collocated_function_space.py +100 -105
  119. warp/fem/space/dof_mapper.py +236 -236
  120. warp/fem/space/function_space.py +148 -145
  121. warp/fem/space/grid_2d_function_space.py +148 -267
  122. warp/fem/space/grid_3d_function_space.py +167 -306
  123. warp/fem/space/hexmesh_function_space.py +253 -352
  124. warp/fem/space/nanogrid_function_space.py +202 -0
  125. warp/fem/space/partition.py +350 -350
  126. warp/fem/space/quadmesh_2d_function_space.py +261 -369
  127. warp/fem/space/restriction.py +161 -160
  128. warp/fem/space/shape/__init__.py +90 -15
  129. warp/fem/space/shape/cube_shape_function.py +728 -738
  130. warp/fem/space/shape/shape_function.py +102 -103
  131. warp/fem/space/shape/square_shape_function.py +611 -611
  132. warp/fem/space/shape/tet_shape_function.py +565 -567
  133. warp/fem/space/shape/triangle_shape_function.py +429 -429
  134. warp/fem/space/tetmesh_function_space.py +224 -292
  135. warp/fem/space/topology.py +297 -295
  136. warp/fem/space/trimesh_2d_function_space.py +153 -221
  137. warp/fem/types.py +77 -77
  138. warp/fem/utils.py +495 -495
  139. warp/jax.py +166 -141
  140. warp/jax_experimental.py +341 -339
  141. warp/native/array.h +1081 -1025
  142. warp/native/builtin.h +1603 -1560
  143. warp/native/bvh.cpp +402 -398
  144. warp/native/bvh.cu +533 -525
  145. warp/native/bvh.h +430 -429
  146. warp/native/clang/clang.cpp +496 -464
  147. warp/native/crt.cpp +42 -32
  148. warp/native/crt.h +352 -335
  149. warp/native/cuda_crt.h +1049 -1049
  150. warp/native/cuda_util.cpp +549 -540
  151. warp/native/cuda_util.h +288 -203
  152. warp/native/cutlass_gemm.cpp +34 -34
  153. warp/native/cutlass_gemm.cu +372 -372
  154. warp/native/error.cpp +66 -66
  155. warp/native/error.h +27 -27
  156. warp/native/exports.h +187 -0
  157. warp/native/fabric.h +228 -228
  158. warp/native/hashgrid.cpp +301 -278
  159. warp/native/hashgrid.cu +78 -77
  160. warp/native/hashgrid.h +227 -227
  161. warp/native/initializer_array.h +32 -32
  162. warp/native/intersect.h +1204 -1204
  163. warp/native/intersect_adj.h +365 -365
  164. warp/native/intersect_tri.h +322 -322
  165. warp/native/marching.cpp +2 -2
  166. warp/native/marching.cu +497 -497
  167. warp/native/marching.h +2 -2
  168. warp/native/mat.h +1545 -1498
  169. warp/native/matnn.h +333 -333
  170. warp/native/mesh.cpp +203 -203
  171. warp/native/mesh.cu +292 -293
  172. warp/native/mesh.h +1887 -1887
  173. warp/native/nanovdb/GridHandle.h +366 -0
  174. warp/native/nanovdb/HostBuffer.h +590 -0
  175. warp/native/nanovdb/NanoVDB.h +6624 -4782
  176. warp/native/nanovdb/PNanoVDB.h +3390 -2553
  177. warp/native/noise.h +850 -850
  178. warp/native/quat.h +1112 -1085
  179. warp/native/rand.h +303 -299
  180. warp/native/range.h +108 -108
  181. warp/native/reduce.cpp +156 -156
  182. warp/native/reduce.cu +348 -348
  183. warp/native/runlength_encode.cpp +61 -61
  184. warp/native/runlength_encode.cu +46 -46
  185. warp/native/scan.cpp +30 -30
  186. warp/native/scan.cu +36 -36
  187. warp/native/scan.h +7 -7
  188. warp/native/solid_angle.h +442 -442
  189. warp/native/sort.cpp +94 -94
  190. warp/native/sort.cu +97 -97
  191. warp/native/sort.h +14 -14
  192. warp/native/sparse.cpp +337 -337
  193. warp/native/sparse.cu +544 -544
  194. warp/native/spatial.h +630 -630
  195. warp/native/svd.h +562 -562
  196. warp/native/temp_buffer.h +30 -30
  197. warp/native/vec.h +1177 -1133
  198. warp/native/volume.cpp +529 -297
  199. warp/native/volume.cu +58 -32
  200. warp/native/volume.h +960 -538
  201. warp/native/volume_builder.cu +446 -425
  202. warp/native/volume_builder.h +34 -19
  203. warp/native/volume_impl.h +61 -0
  204. warp/native/warp.cpp +1057 -1052
  205. warp/native/warp.cu +2949 -2828
  206. warp/native/warp.h +321 -305
  207. warp/optim/__init__.py +9 -9
  208. warp/optim/adam.py +120 -120
  209. warp/optim/linear.py +1104 -939
  210. warp/optim/sgd.py +104 -92
  211. warp/render/__init__.py +10 -10
  212. warp/render/render_opengl.py +3356 -3204
  213. warp/render/render_usd.py +768 -749
  214. warp/render/utils.py +152 -150
  215. warp/sim/__init__.py +52 -59
  216. warp/sim/articulation.py +685 -685
  217. warp/sim/collide.py +1594 -1590
  218. warp/sim/import_mjcf.py +489 -481
  219. warp/sim/import_snu.py +220 -221
  220. warp/sim/import_urdf.py +536 -516
  221. warp/sim/import_usd.py +887 -881
  222. warp/sim/inertia.py +316 -317
  223. warp/sim/integrator.py +234 -233
  224. warp/sim/integrator_euler.py +1956 -1956
  225. warp/sim/integrator_featherstone.py +1917 -1991
  226. warp/sim/integrator_xpbd.py +3288 -3312
  227. warp/sim/model.py +4473 -4314
  228. warp/sim/particles.py +113 -112
  229. warp/sim/render.py +417 -403
  230. warp/sim/utils.py +413 -410
  231. warp/sparse.py +1289 -1227
  232. warp/stubs.py +2192 -2469
  233. warp/tape.py +1162 -225
  234. warp/tests/__init__.py +1 -1
  235. warp/tests/__main__.py +4 -4
  236. warp/tests/assets/test_index_grid.nvdb +0 -0
  237. warp/tests/assets/torus.usda +105 -105
  238. warp/tests/aux_test_class_kernel.py +26 -26
  239. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  240. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  241. warp/tests/aux_test_dependent.py +20 -22
  242. warp/tests/aux_test_grad_customs.py +21 -23
  243. warp/tests/aux_test_reference.py +9 -11
  244. warp/tests/aux_test_reference_reference.py +8 -10
  245. warp/tests/aux_test_square.py +15 -17
  246. warp/tests/aux_test_unresolved_func.py +14 -14
  247. warp/tests/aux_test_unresolved_symbol.py +14 -14
  248. warp/tests/disabled_kinematics.py +237 -239
  249. warp/tests/run_coverage_serial.py +31 -31
  250. warp/tests/test_adam.py +155 -157
  251. warp/tests/test_arithmetic.py +1088 -1124
  252. warp/tests/test_array.py +2415 -2326
  253. warp/tests/test_array_reduce.py +148 -150
  254. warp/tests/test_async.py +666 -656
  255. warp/tests/test_atomic.py +139 -141
  256. warp/tests/test_bool.py +212 -149
  257. warp/tests/test_builtins_resolution.py +1290 -1292
  258. warp/tests/test_bvh.py +162 -171
  259. warp/tests/test_closest_point_edge_edge.py +227 -228
  260. warp/tests/test_codegen.py +562 -553
  261. warp/tests/test_compile_consts.py +217 -101
  262. warp/tests/test_conditional.py +244 -246
  263. warp/tests/test_copy.py +230 -215
  264. warp/tests/test_ctypes.py +630 -632
  265. warp/tests/test_dense.py +65 -67
  266. warp/tests/test_devices.py +89 -98
  267. warp/tests/test_dlpack.py +528 -529
  268. warp/tests/test_examples.py +403 -378
  269. warp/tests/test_fabricarray.py +952 -955
  270. warp/tests/test_fast_math.py +60 -54
  271. warp/tests/test_fem.py +1298 -1278
  272. warp/tests/test_fp16.py +128 -130
  273. warp/tests/test_func.py +336 -337
  274. warp/tests/test_generics.py +596 -571
  275. warp/tests/test_grad.py +885 -640
  276. warp/tests/test_grad_customs.py +331 -336
  277. warp/tests/test_hash_grid.py +208 -164
  278. warp/tests/test_import.py +37 -39
  279. warp/tests/test_indexedarray.py +1132 -1134
  280. warp/tests/test_intersect.py +65 -67
  281. warp/tests/test_jax.py +305 -307
  282. warp/tests/test_large.py +169 -164
  283. warp/tests/test_launch.py +352 -354
  284. warp/tests/test_lerp.py +217 -261
  285. warp/tests/test_linear_solvers.py +189 -171
  286. warp/tests/test_lvalue.py +419 -493
  287. warp/tests/test_marching_cubes.py +63 -65
  288. warp/tests/test_mat.py +1799 -1827
  289. warp/tests/test_mat_lite.py +113 -115
  290. warp/tests/test_mat_scalar_ops.py +2905 -2889
  291. warp/tests/test_math.py +124 -193
  292. warp/tests/test_matmul.py +498 -499
  293. warp/tests/test_matmul_lite.py +408 -410
  294. warp/tests/test_mempool.py +186 -190
  295. warp/tests/test_mesh.py +281 -324
  296. warp/tests/test_mesh_query_aabb.py +226 -241
  297. warp/tests/test_mesh_query_point.py +690 -702
  298. warp/tests/test_mesh_query_ray.py +290 -303
  299. warp/tests/test_mlp.py +274 -276
  300. warp/tests/test_model.py +108 -110
  301. warp/tests/test_module_hashing.py +111 -0
  302. warp/tests/test_modules_lite.py +36 -39
  303. warp/tests/test_multigpu.py +161 -163
  304. warp/tests/test_noise.py +244 -248
  305. warp/tests/test_operators.py +248 -250
  306. warp/tests/test_options.py +121 -125
  307. warp/tests/test_peer.py +131 -137
  308. warp/tests/test_pinned.py +76 -78
  309. warp/tests/test_print.py +52 -54
  310. warp/tests/test_quat.py +2084 -2086
  311. warp/tests/test_rand.py +324 -288
  312. warp/tests/test_reload.py +207 -217
  313. warp/tests/test_rounding.py +177 -179
  314. warp/tests/test_runlength_encode.py +188 -190
  315. warp/tests/test_sim_grad.py +241 -0
  316. warp/tests/test_sim_kinematics.py +89 -97
  317. warp/tests/test_smoothstep.py +166 -168
  318. warp/tests/test_snippet.py +303 -266
  319. warp/tests/test_sparse.py +466 -460
  320. warp/tests/test_spatial.py +2146 -2148
  321. warp/tests/test_special_values.py +362 -0
  322. warp/tests/test_streams.py +484 -473
  323. warp/tests/test_struct.py +708 -675
  324. warp/tests/test_tape.py +171 -148
  325. warp/tests/test_torch.py +741 -743
  326. warp/tests/test_transient_module.py +85 -87
  327. warp/tests/test_types.py +554 -659
  328. warp/tests/test_utils.py +488 -499
  329. warp/tests/test_vec.py +1262 -1268
  330. warp/tests/test_vec_lite.py +71 -73
  331. warp/tests/test_vec_scalar_ops.py +2097 -2099
  332. warp/tests/test_verify_fp.py +92 -94
  333. warp/tests/test_volume.py +961 -736
  334. warp/tests/test_volume_write.py +338 -265
  335. warp/tests/unittest_serial.py +38 -37
  336. warp/tests/unittest_suites.py +367 -359
  337. warp/tests/unittest_utils.py +434 -578
  338. warp/tests/unused_test_misc.py +69 -71
  339. warp/tests/walkthrough_debug.py +85 -85
  340. warp/thirdparty/appdirs.py +598 -598
  341. warp/thirdparty/dlpack.py +143 -143
  342. warp/thirdparty/unittest_parallel.py +563 -561
  343. warp/torch.py +321 -295
  344. warp/types.py +4941 -4450
  345. warp/utils.py +1008 -821
  346. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
  347. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
  348. warp_lang-1.2.0.dist-info/RECORD +359 -0
  349. warp/examples/assets/cube.usda +0 -42
  350. warp/examples/assets/sphere.usda +0 -56
  351. warp/examples/assets/torus.usda +0 -105
  352. warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
  353. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  354. warp_lang-1.0.2.dist-info/RECORD +0 -352
  355. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
  356. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,2099 +1,2097 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
-
8
- import unittest
9
-
10
- import numpy as np
11
-
12
- import warp as wp
13
- from warp.tests.unittest_utils import *
14
-
15
- wp.init()
16
-
17
# NumPy scalar types the tests are parameterized over.
# Note: np.byte and np.ubyte are NumPy aliases of np.int8 and np.uint8, so
# those two types each appear twice in the integer lists.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

np_int_types = [*np_signed_int_types, *np_unsigned_int_types]

np_float_types = [np.float16, np.float32, np.float64]

np_scalar_types = [*np_int_types, *np_float_types]
38
-
39
-
40
def randvals(rng, shape, dtype):
    """Draw random test values of NumPy type ``dtype`` with the given ``shape``.

    Float types are sampled from a standard normal distribution and cast to
    ``dtype``. Integer types are sampled uniformly from ``[1, 3)`` for 8-bit
    types and from ``[1, 5)`` for wider types, so they are small and
    strictly positive.
    """
    if dtype in np_float_types:
        return rng.standard_normal(size=shape).astype(dtype)

    # 8-bit integer types get a narrower range than the wider integer types
    upper = 3 if dtype in (np.int8, np.uint8, np.byte, np.ubyte) else 5
    return rng.integers(1, high=upper, size=shape, dtype=dtype)
46
-
47
-
48
# Kernels built by getkernel(), keyed by "<function name>_<suffix>".
kernel_cache = {}
49
-
50
-
51
def getkernel(func, suffix=""):
    """Return the Warp kernel for ``func``, building it on first request.

    Kernels are memoized in the module-level ``kernel_cache`` under the key
    ``func.__name__ + "_" + suffix``. The same key is passed to ``wp.Kernel``,
    so one Python function can yield distinct kernels per suffix (e.g. per dtype).
    """
    cache_key = func.__name__ + "_" + suffix
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    kernel = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = kernel
    return kernel
56
-
57
-
58
def get_select_kernel(dtype):
    """Return a cached kernel that copies ``input[index]`` into ``out[0]``.

    The kernel is specialized for the 1-D array element type ``dtype`` and
    cached through ``getkernel`` using ``dtype.__name__`` as the suffix.
    """
    suffix = dtype.__name__

    # The function name must stay fixed: it is part of the cache key.
    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        # copy the selected element into the first output slot
        out[0] = input[index]

    return getkernel(output_select_kernel_fn, suffix=suffix)
67
-
68
-
69
def get_select_kernel2(dtype):
    """Return a cached kernel that copies ``input[index0, index1]`` into ``out[0]``.

    Same as ``get_select_kernel`` but for a 2-D input array; the kernel is
    specialized for ``dtype`` and cached with ``dtype.__name__`` as the suffix.
    """
    suffix = dtype.__name__

    # The function name must stay fixed: it is part of the cache key.
    def output_select_kernel2_fn(
        input: wp.array(dtype=dtype, ndim=2),
        index0: int,
        index1: int,
        out: wp.array(dtype=dtype),
    ):
        # copy the selected 2-D element into the first output slot
        out[0] = input[index0, index1]

    return getkernel(output_select_kernel2_fn, suffix=suffix)
79
-
80
-
81
def test_arrays(test, device, dtype):
    """Round-trip random host data through vector-typed Warp arrays.

    Vector lengths 2-5 are checked once. Lengths 2-4 are then checked again
    with freshly declared vector types.
    """
    rng = np.random.default_rng(123)
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    lengths = (2, 3, 4, 5)

    # Draw all host data first so the RNG is consumed in the order 2, 3, 4, 5.
    host = {n: randvals(rng, (10, n), dtype) for n in lengths}

    def roundtrip(n):
        # declare the vector type, upload the (10, n) data, and read it back
        vec_type = wp.types.vector(length=n, dtype=wptype)
        arr = wp.array(host[n], dtype=vec_type, requires_grad=True, device=device)
        assert_np_equal(arr.numpy(), host[n], tol=1.0e-6)

    for n in lengths:
        roundtrip(n)

    # NOTE(review): second pass with re-declared vector types of length 2-4;
    # presumably this guards against issues when re-declaring a vector type — confirm intent.
    for n in lengths[:3]:
        roundtrip(n)
116
-
117
-
118
def test_components(test, device, dtype):
    """Check Python-side component access on a 3-component vector.

    Exercises ``__getitem__`` and ``__setitem__`` with integer indices and
    with slices. This is especially important for float16, whose components
    need special handling internally.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def expect(seq, *values):
        # compare the leading components of ``seq`` against ``values`` one by one
        for idx, value in enumerate(values):
            test.assertEqual(seq[idx], value)

    v = vec3(1, 2, 3)

    # __getitem__: individual components, then slices
    expect(v, 1, 2, 3)
    expect(v[:], 1, 2, 3)
    expect(v[1:], 2, 3)
    expect(v[:2], 1, 2)
    expect(v[::2], 1, 3)

    # __setitem__: individual components
    for idx, value in enumerate((4, 5, 6)):
        v[idx] = value
    expect(v, 4, 5, 6)

    # __setitem__: slices; components outside the slice keep their values
    v[:] = [7, 8, 9]
    expect(v, 7, 8, 9)

    v[1:] = [10, 11]
    expect(v, 7, 10, 11)

    v[:2] = [12, 13]
    expect(v, 12, 13, 11)

    v[::2] = [14, 15]
    expect(v, 14, 13, 15)
178
-
179
-
180
def test_py_arithmetic_ops(test, device, dtype):
    """Check Python-scope arithmetic operators on a 3-component vector type.

    For integer element types, expected values are passed through the
    element's ctypes type so they wrap exactly like fixed-width integers.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    is_int = wptype in wp.types.int_types

    def make_vec(*args):
        if not is_int:
            return args
        # cast to the correct integer type to simulate wrapping
        return tuple(wptype._type_(x).value for x in args)

    def check(actual, *expected):
        test.assertSequenceEqual(actual, make_vec(*expected))

    vec_cls = wp.vec(3, wptype)

    # unary and vector-vector operators
    v = vec_cls(1, -2, 3)
    check(+v, 1, -2, 3)
    check(-v, -1, 2, -3)
    check(v + vec_cls(5, 5, 5), 6, 3, 8)
    check(v - vec_cls(5, 5, 5), -4, -7, -2)

    # scalar-vector operators, with the scalar on either side
    v = vec_cls(2, 4, 6)
    check(v * wptype(2), 4, 8, 12)
    check(wptype(2) * v, 4, 8, 12)
    check(v / wptype(2), 1, 2, 3)
    check(wptype(24) / v, 12, 6, 4)
203
-
204
-
205
def test_constructors(test, device, dtype, register_kernels=False):
    """Check in-kernel vector constructors for vectors of length 2-5.

    Two kernels are exercised:

    * ``check_scalar_constructor`` fills every component from one scalar;
    * ``check_vector_constructors`` takes one scalar per component.

    Both also write ``2 * component`` into separate scalar outputs. For float
    dtypes, this lets the gradient back to the inputs be checked through
    ``wp.Tape``. When ``register_kernels`` is true, the kernels are only
    created (and cached) and the function returns without launching anything.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_scalar_constructor(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = vec2(input[0])
        v3result = vec3(input[0])
        v4result = vec4(input[0])
        v5result = vec5(input[0])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply outputs by 2 so we've got something to backpropagate
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    def check_vector_constructors(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = vec2(input[0], input[1])
        v3result = vec3(input[2], input[3], input[4])
        v4result = vec4(input[5], input[6], input[7], input[8])
        v5result = vec5(input[9], input[10], input[11], input[12], input[13])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply the output by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    vec_kernel = getkernel(check_vector_constructors, suffix=dtype.__name__)
    kernel = getkernel(check_scalar_constructor, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)

    # ---- scalar-fill constructor -------------------------------------------------
    src = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)

    # Outputs in kernel-argument order: the four vectors (v2..v5), then one
    # scalar per vector component (v20, v21, v30, ..., v54).
    vec_outputs = [wp.zeros(1, dtype=vec_type, device=device) for vec_type in (vec2, vec3, vec4, vec5)]
    component_outputs = [
        wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for n in lengths for _ in range(n)
    ]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[src],
            outputs=vec_outputs + component_outputs,
            device=device,
        )

    if dtype in np_float_types:
        # every component copies src[0], so d(2 * component) / d(src[0]) == 2
        for loss in component_outputs:
            tape.backward(loss=loss)
            test.assertEqual(tape.gradients[src].numpy()[0], 2.0)
            tape.zero()

    val = src.numpy()[0]
    for n, out in zip(lengths, vec_outputs):
        assert_np_equal(out.numpy()[0], np.array([val] * n), tol=1.0e-6)
    for out in component_outputs:
        assert_np_equal(out.numpy()[0], 2 * val, tol=1.0e-6)

    # ---- per-component constructor -----------------------------------------------
    src = wp.array(randvals(rng, [14], dtype), requires_grad=True, device=device)
    tape = wp.Tape()
    with tape:
        wp.launch(
            vec_kernel,
            dim=1,
            inputs=[src],
            outputs=vec_outputs + component_outputs,
            device=device,
        )

    if dtype in np_float_types:
        # component output i depends only on src[i], with a factor of 2
        for i, loss in enumerate(component_outputs):
            tape.backward(loss=loss)
            grad = tape.gradients[src].numpy()
            expected_grad = np.zeros_like(grad)
            expected_grad[i] = 2
            assert_np_equal(grad, expected_grad, tol=tol)
            tape.zero()

    src_np = src.numpy()

    # vector components consume src consecutively: v2 <- src[0:2], v3 <- src[2:5], ...
    offset = 0
    for n, out in zip(lengths, vec_outputs):
        out_np = out.numpy()
        for k in range(n):
            assert_np_equal(out_np[0, k], src_np[offset + k], tol=tol)
        offset += n

    for i, out in enumerate(component_outputs):
        assert_np_equal(out.numpy()[0], 2 * src_np[i], tol=tol)
447
-
448
-
449
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Exercise anonymous ``wp.vector(...)`` construction inside kernels.

    Two kernels build vectors of length 2, 3, 4 and 5 without a named vector
    type. ``check_scalar_init`` fills each vector from one scalar via
    ``wp.vector(x, length=n)``. ``check_component_init`` passes every component
    explicitly. Both kernels write ``2 * component`` for all 14 components into
    a flat output array.

    The forward results are compared with NumPy. For floating-point dtypes, the
    gradient of each output element with respect to the input array is then
    checked one element at a time.

    When ``register_kernels`` is true, the function only builds the kernels and
    returns.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], length=2)
        v3result = wp.vector(input[1], length=3)
        v4result = wp.vector(input[2], length=4)
        v5result = wp.vector(input[3], length=5)

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], input[1])
        v3result = wp.vector(input[2], input[3], input[4])
        v4result = wp.vector(input[5], input[6], input[7], input[8])
        v5result = wp.vector(input[9], input[10], input[11], input[12], input[13])

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    total = sum(lengths)  # 14 flattened vector components

    def check_grads(kernel, src, dst, source_index):
        # Backpropagate from each output element in turn (picked out by
        # output_select_kernel) and compare d(dst[i])/d(src) against a one-hot
        # vector holding 2 at position source_index(i).
        picked = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(total):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[src], outputs=[dst], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[dst, i], outputs=[picked], device=device)

            tape.backward(loss=picked)
            expected = np.zeros_like(src.numpy())
            expected[source_index(i)] = 2

            assert_np_equal(tape.gradients[src].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # Scalar-broadcast constructor: input element k fills every component of the k-th vector.
    scalars = wp.array(randvals(rng, [len(lengths)], dtype), requires_grad=True, device=device)
    flat = wp.zeros(total, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[scalars], outputs=[flat], device=device)

    # owner[i] is the index of the scalar that produced flattened component i.
    owner = [k for k, n in enumerate(lengths) for _ in range(n)]
    start = 0
    for k, n in enumerate(lengths):
        stop = start + n
        assert_np_equal(flat.numpy()[start:stop], 2 * np.array([scalars.numpy()[k]] * n), tol=1.0e-6)
        start = stop

    if dtype in np_float_types:
        check_grads(scalar_kernel, scalars, flat, lambda i: owner[i])

    # Component-wise constructor: output element i depends only on input element i.
    components = wp.array(randvals(rng, [total], dtype), requires_grad=True, device=device)
    flat = wp.zeros(total, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[components], outputs=[flat], device=device)

    assert_np_equal(flat.numpy(), 2 * components.numpy(), tol=1.0e-6)

    if dtype in np_float_types:
        check_grads(component_kernel, components, flat, lambda i: i)
570
-
571
-
572
def test_indexing(test, device, dtype, register_kernels=False):
    """Check reading individual components of vec2..vec5 array elements in a kernel.

    The kernel writes ``2 * vN[0][j]`` for every component ``j`` of each input
    vector into its own one-element output array. For floating-point dtypes,
    the gradient of each output with respect to all input components is
    checked first. The forward values are then compared with NumPy.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_indexing(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2[0][0]
        v21[0] = wptype(2) * v2[0][1]

        v30[0] = wptype(2) * v3[0][0]
        v31[0] = wptype(2) * v3[0][1]
        v32[0] = wptype(2) * v3[0][2]

        v40[0] = wptype(2) * v4[0][0]
        v41[0] = wptype(2) * v4[0][1]
        v42[0] = wptype(2) * v4[0][2]
        v43[0] = wptype(2) * v4[0][3]

        v50[0] = wptype(2) * v5[0][0]
        v51[0] = wptype(2) * v5[0][1]
        v52[0] = wptype(2) * v5[0][2]
        v53[0] = wptype(2) * v5[0][3]
        v54[0] = wptype(2) * v5[0][4]

    kernel = getkernel(check_indexing, suffix=dtype.__name__)

    if register_kernels:
        return

    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    inputs = [v2, v3, v4, v5]

    # One scalar output per input component, in kernel parameter order (v20, v21, v30, ..., v54).
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=inputs, outputs=outputs, device=device)

    if dtype in np_float_types:
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in inputs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = 2
            assert_np_equal(allgrads, expected_grads, tol=tol)
            tape.zero()

    # Pair each output with the (vector array, component index) it was read from.
    sources = [(vec, j) for vec, n in zip(inputs, (2, 3, 4, 5)) for j in range(n)]
    for out, (vec, j) in zip(outputs, sources):
        assert_np_equal(out.numpy()[0], 2.0 * vec.numpy()[0, j], tol=tol)
683
-
684
-
685
def test_equality(test, device, dtype, register_kernels=False):
    """Check vector (in)equality for vec2..vec5 via ``wp.expect_eq``/``wp.expect_neq``.

    For each length, the first array holds a reference vector that must compare
    equal to itself. Every following array holds a vector that differs from the
    reference in exactly one component, and must compare unequal to it.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_equality(
        v20: wp.array(dtype=vec2),
        v21: wp.array(dtype=vec2),
        v22: wp.array(dtype=vec2),
        v30: wp.array(dtype=vec3),
        v31: wp.array(dtype=vec3),
        v32: wp.array(dtype=vec3),
        v33: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v41: wp.array(dtype=vec4),
        v42: wp.array(dtype=vec4),
        v43: wp.array(dtype=vec4),
        v44: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
        v51: wp.array(dtype=vec5),
        v52: wp.array(dtype=vec5),
        v53: wp.array(dtype=vec5),
        v54: wp.array(dtype=vec5),
        v55: wp.array(dtype=vec5),
    ):
        wp.expect_eq(v20[0], v20[0])
        wp.expect_neq(v21[0], v20[0])
        wp.expect_neq(v22[0], v20[0])

        wp.expect_eq(v30[0], v30[0])
        wp.expect_neq(v31[0], v30[0])
        wp.expect_neq(v32[0], v30[0])
        wp.expect_neq(v33[0], v30[0])

        wp.expect_eq(v40[0], v40[0])
        wp.expect_neq(v41[0], v40[0])
        wp.expect_neq(v42[0], v40[0])
        wp.expect_neq(v43[0], v40[0])
        wp.expect_neq(v44[0], v40[0])

        wp.expect_eq(v50[0], v50[0])
        wp.expect_neq(v51[0], v50[0])
        wp.expect_neq(v52[0], v50[0])
        wp.expect_neq(v53[0], v50[0])
        wp.expect_neq(v54[0], v50[0])
        wp.expect_neq(v55[0], v50[0])

    kernel = getkernel(check_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    def one_off(n):
        # [1, 2, ..., n], followed by n copies that each negate a single component.
        base = [float(k + 1) for k in range(n)]
        return [base] + [[-x if k == m else x for k, x in enumerate(base)] for m in range(n)]

    def make(values, vtype):
        return wp.array(values, dtype=vtype, requires_grad=True, device=device)

    # Length 2 uses hand-picked "different" vectors instead of negated components.
    arrays = [make(values, vec2) for values in ([1.0, 2.0], [1.0, 3.0], [3.0, 2.0])]
    for n, vtype in ((3, vec3), (4, vec4), (5, vec5)):
        arrays.extend(make(values, vtype) for values in one_off(n))

    wp.launch(
        kernel,
        dim=1,
        inputs=arrays,
        outputs=[],
        device=device,
    )
786
-
787
-
788
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Test left multiplication of vec2..vec5 by a scalar (``s * v``).

    The kernel computes ``s[0] * vN[0]`` for each length N. It writes twice each
    component of the product into its own one-element output array
    (``v20, v21, v30, ..., v54``). Forward values are compared with NumPy. For
    floating-point dtypes, the gradients of every output (vec5 included) are
    checked with respect to the scalar and to all input vector components.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_mul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = s[0] * v2[0]
        v3result = s[0] * v3[0]
        v4result = s[0] * v4[0]
        v5result = s[0] * v5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vectors = [v2, v3, v4, v5]

    # One output per product component, in kernel parameter order (v20 ... v54).
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, *vectors], outputs=outputs, device=device)

    scale = s.numpy()[0]
    # (input vector array, component index) that produced each output.
    sources = [(vec, j) for vec, n in zip(vectors, (2, 3, 4, 5)) for j in range(n)]
    for out, (vec, j) in zip(outputs, sources):
        # vec2 results use the base tolerance; longer vectors use a 10x looser one.
        check_tol = tol if vec is v2 else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * scale * vec.numpy()[0, j], tol=check_tol)

    incmps = np.concatenate([v.numpy()[0] for v in vectors])

    if dtype in np_float_types:
        # Backpropagate from every output, including the five vec5 components.
        # Expect d(out)/ds == 2 * v[j], and d(out)/dv == 2 * s at the matching
        # component only (zero everywhere else, across all four input vectors).
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vectors])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = scale * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
918
-
919
-
920
def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=False):
    """Test right multiplication of vec2..vec5 by a scalar (``v * s``).

    The kernel computes ``vN[0] * s[0]`` for each length N. It writes twice each
    component of the product into its own one-element output array
    (``v20, v21, v30, ..., v54``). Forward values are compared with NumPy. For
    floating-point dtypes, the gradients of every output (vec5 included) are
    checked with respect to the scalar and to all input vector components.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_rightmul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] * s[0]
        v3result = v3[0] * s[0]
        v4result = v4[0] * s[0]
        v5result = v5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_rightmul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vectors = [v2, v3, v4, v5]

    # One output per product component, in kernel parameter order (v20 ... v54).
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, *vectors], outputs=outputs, device=device)

    scale = s.numpy()[0]
    # (input vector array, component index) that produced each output.
    sources = [(vec, j) for vec, n in zip(vectors, (2, 3, 4, 5)) for j in range(n)]
    for out, (vec, j) in zip(outputs, sources):
        # vec2 results use the base tolerance; longer vectors use a 10x looser one.
        check_tol = tol if vec is v2 else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * scale * vec.numpy()[0, j], tol=check_tol)

    incmps = np.concatenate([v.numpy()[0] for v in vectors])

    if dtype in np_float_types:
        # Backpropagate from every output, including the five vec5 components.
        # Expect d(out)/ds == 2 * v[j], and d(out)/dv == 2 * s at the matching
        # component only (zero everywhere else, across all four input vectors).
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vectors])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = scale * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
1050
-
1051
-
1052
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Test component-wise vector multiplication ``wp.cw_mul(s, v)`` for vec2..vec5.

    For each length N, the kernel multiplies the scale vector ``sN[0]`` by
    ``vN[0]`` component-wise. It writes twice every product component into its
    own one-element output array. Forward values are compared with NumPy. For
    floating-point dtypes, each output's gradient with respect to both operands
    must be ``2 * other_operand[j]`` at the matching component and zero
    elsewhere.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_mul(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_mul(s2[0], v2[0])
        v3result = wp.cw_mul(s3[0], v3[0])
        v4result = wp.cw_mul(s4[0], v4[0])
        v5result = wp.cw_mul(s5[0], v5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    shapes = ((2, vec2), (3, vec3), (4, vec4), (5, vec5))

    # Draw random values in the same order as the kernel arguments: s2..s5, then v2..v5.
    scales = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device) for n, vtype in shapes
    ]
    vectors = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device) for n, vtype in shapes
    ]
    # One output per product component, in kernel parameter order (v20 ... v54).
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(2 + 3 + 4 + 5)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*scales, *vectors], outputs=outputs, device=device)

    # (scale array, vector array, component index) that produced each output.
    operands = [(sc, vec, j) for sc, vec, (n, _) in zip(scales, vectors, shapes) for j in range(n)]
    for out, (sc, vec, j) in zip(outputs, operands):
        assert_np_equal(out.numpy()[0], 2 * sc.numpy()[0, j] * vec.numpy()[0, j], tol=10 * tol)

    incmps = np.concatenate([v.numpy()[0] for v in vectors])
    scmps = np.concatenate([v.numpy()[0] for v in scales])

    if dtype in np_float_types:
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)

            sgrads = np.concatenate([tape.gradients[a].numpy()[0] for a in scales])
            expected_sgrads = np.zeros_like(sgrads)
            expected_sgrads[i] = incmps[i] * 2
            assert_np_equal(sgrads, expected_sgrads, tol=10 * tol)

            vgrads = np.concatenate([tape.gradients[a].numpy()[0] for a in vectors])
            expected_vgrads = np.zeros_like(vgrads)
            expected_vgrads[i] = scmps[i] * 2
            assert_np_equal(vgrads, expected_vgrads, tol=10 * tol)

            tape.zero()
1195
-
1196
-
1197
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Verify ``vector / scalar`` for vector lengths 2 through 5.

    The kernel divides one vector of each length by the single scalar ``s``
    and writes twice each result component into its own one-element output
    array (the factor of 2 gives the backward pass a non-trivial adjoint).
    Integer dtypes are compared against NumPy floor division (``//``), float
    dtypes against true division. Float dtypes also check the gradient of
    every output with respect to ``s`` and to every vector component.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; integer dtypes must match exactly
    tolerance_by_dtype = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerance_by_dtype.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_div(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] / s[0]
        v3result = v3[0] / s[0]
        v4result = v4[0] / s[0]
        v5result = v5[0] / s[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_div, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)

    # random draw order: the scalar first, then one vector per length
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    numerators = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device)
        for n, vtype in zip(lengths, (vec2, vec3, vec4, vec5))
    ]
    # one scalar output per vector component: v20, v21, v30, ..., v54
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, *numerators], outputs=outputs, device=device)

    # (input vector, component index) feeding each output, in kernel order
    sources = [(vec, j) for vec, n in zip(numerators, lengths) for j in range(n)]
    divisor = s.numpy()[0]
    integer_dtype = dtype in np_int_types

    for k, (out, (vec, j)) in enumerate(zip(outputs, sources)):
        numer = vec.numpy()[0, j]
        expected = 2 * (numer // divisor) if integer_dtype else 2 * numer / divisor
        # the two length-2 components use the base tolerance, the rest a looser one
        assert_np_equal(out.numpy()[0], expected, tol=tol if k < 2 else 10 * tol)

    incmps = np.concatenate([vec.numpy()[0] for vec in numerators])

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d/ds (2 v / s) = -2 v / s^2
        sgrad = tape.gradients[s].numpy()[0]
        assert_np_equal(sgrad, -2 * incmps[i] / (divisor * divisor), tol=10 * tol)

        # d/dv (2 v / s) = 2 / s, nonzero only for the component behind this loss
        allgrads = np.concatenate([tape.gradients[vec].numpy()[0] for vec in numerators])
        expected_grads = np.zeros_like(allgrads)
        expected_grads[i] = 2 / divisor
        assert_np_equal(allgrads, expected_grads, tol=tol)

        tape.zero()
1351
-
1352
-
1353
def test_cw_division(test, device, dtype, register_kernels=False):
    """Verify component-wise division ``wp.cw_div(v, s)`` for vector lengths 2-5.

    Each quotient component is doubled and written to its own one-element
    output so the backward pass has a non-trivial adjoint. Integer dtypes are
    compared against NumPy floor division (``//``), float dtypes against true
    division. Float dtypes also check the gradients with respect to both the
    numerator vectors and the divisor vectors.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; integer dtypes must match exactly
    tolerance_by_dtype = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerance_by_dtype.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_cw_div(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_div(v2[0], s2[0])
        v3result = wp.cw_div(v3[0], s3[0])
        v4result = wp.cw_div(v4[0], s4[0])
        v5result = wp.cw_div(v5[0], s5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_div, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # one single-element vector array per length, drawn in length order
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device)
            for n, vtype in zip(lengths, vec_types)
        ]

    s_vecs = random_vectors()  # s2..s5 (divisors) are drawn before the numerators
    v_vecs = random_vectors()  # v2..v5 (numerators)
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=outputs, device=device)

    # (numerator, divisor, component index) feeding each output, in kernel order
    sources = [(v, s, j) for v, s, n in zip(v_vecs, s_vecs, lengths) for j in range(n)]
    integer_dtype = dtype in np_int_types

    for out, (v, s, j) in zip(outputs, sources):
        num = v.numpy()[0, j]
        den = s.numpy()[0, j]
        expected = 2 * (num // den) if integer_dtype else 2 * num / den
        assert_np_equal(out.numpy()[0], expected, tol=tol)

    if dtype not in np_float_types:
        return

    incmps = np.concatenate([v.numpy()[0] for v in v_vecs])
    scmps = np.concatenate([s.numpy()[0] for s in s_vecs])

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d/ds (2 v / s) = -2 v / s^2, only for the component behind this loss
        sgrads = np.concatenate([tape.gradients[s].numpy()[0] for s in s_vecs])
        expected_grads = np.zeros_like(sgrads)
        expected_grads[i] = -incmps[i] * 2 / (scmps[i] * scmps[i])
        assert_np_equal(sgrads, expected_grads, tol=20 * tol)

        # d/dv (2 v / s) = 2 / s
        allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in v_vecs])
        expected_grads = np.zeros_like(allgrads)
        expected_grads[i] = 2 / scmps[i]
        assert_np_equal(allgrads, expected_grads, tol=tol)

        tape.zero()
1519
-
1520
-
1521
def test_addition(test, device, dtype, register_kernels=False):
    """Verify ``vector + vector`` for vector lengths 2 through 5.

    Each component of every sum is doubled and written to its own one-element
    output. Float dtypes additionally check that each output's gradient with
    respect to the matching component of both operands is 2, and 0 for every
    other component.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; integer dtypes must match exactly
    tolerance_by_dtype = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerance_by_dtype.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_add(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] + s2[0]
        v3result = v3[0] + s3[0]
        v4result = v4[0] + s4[0]
        v5result = v5[0] + s5[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_add, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # one single-element vector array per length, drawn in length order
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device)
            for n, vtype in zip(lengths, vec_types)
        ]

    s_vecs = random_vectors()  # s2..s5 are drawn first
    v_vecs = random_vectors()  # v2..v5
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=outputs, device=device)

    # (left operand, right operand, component index) feeding each output
    sources = [(v, s, j) for v, s, n in zip(v_vecs, s_vecs, lengths) for j in range(n)]
    last = len(outputs) - 1

    for k, (out, (v, s, j)) in enumerate(zip(outputs, sources)):
        expected = 2 * (v.numpy()[0, j] + s.numpy()[0, j])
        # the final length-5 component gets a slightly looser tolerance
        assert_np_equal(out.numpy()[0], expected, tol=2 * tol if k == last else tol)

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d/ds 2 (v + s) = 2, only for the component behind this loss
        sgrads = np.concatenate([tape.gradients[s].numpy()[0] for s in s_vecs])
        expected_grads = np.zeros_like(sgrads)
        expected_grads[i] = 2
        assert_np_equal(sgrads, expected_grads, tol=10 * tol)

        # d/dv 2 (v + s) = 2 as well, so the same expectation applies
        allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in v_vecs])
        assert_np_equal(allgrads, expected_grads, tol=tol)

        tape.zero()
1660
-
1661
-
1662
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Verify ``wp.dot`` for vector lengths 2 through 5, including gradients.

    Each dot product is doubled and written to a one-element output. Float
    dtypes also check that the gradient with respect to each operand equals
    twice the other operand.
    """
    rng = np.random.default_rng(123)

    # absolute tolerance per float dtype; integer dtypes must match exactly
    tolerance_by_dtype = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerance_by_dtype.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_dot(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        dot2[0] = wptype(2) * wp.dot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.dot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.dot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.dot(v5[0], s5[0])

    kernel = getkernel(check_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # one single-element vector array per length, drawn in length order
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vtype, requires_grad=True, device=device)
            for n, vtype in zip(lengths, vec_types)
        ]

    s_vecs = random_vectors()  # s2..s5 are drawn first
    v_vecs = random_vectors()  # v2..v5
    dots = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in lengths]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=dots, device=device)

    for dot, s, v in zip(dots, s_vecs, v_vecs):
        assert_np_equal(dot.numpy()[0], 2.0 * (v.numpy() * s.numpy()).sum(), tol=10 * tol)

    if dtype not in np_float_types:
        return

    last = len(dots) - 1
    for k, (dot, s, v) in enumerate(zip(dots, s_vecs, v_vecs)):
        tape.backward(loss=dot)

        # d/ds 2 (v . s) = 2 v
        sgrads = tape.gradients[s].numpy()[0]
        assert_np_equal(sgrads, 2.0 * v.numpy()[0], tol=10 * tol)

        # d/dv 2 (v . s) = 2 s; the length-5 case gets a looser tolerance
        vgrads = tape.gradients[v].numpy()[0]
        assert_np_equal(vgrads, 2.0 * s.numpy()[0], tol=10 * tol if k == last else tol)

        tape.zero()
1781
-
1782
-
1783
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Check that a kernel declared with one set of vector types accepts
    arguments built from a second, independently created set of vector types
    with the same lengths and scalar type.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # vector types used in the kernel signature
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in range(2, 6))

    # a second set of vector types declared with identical parameters
    vec2_equiv, vec3_equiv, vec4_equiv, vec5_equiv = (wp.types.vector(length=n, dtype=wptype) for n in range(2, 6))

    # declare kernel with original types
    def check_equivalence(
        v2: vec2,
        v3: vec3,
        v4: vec4,
        v5: vec5,
    ):
        wp.expect_eq(v2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

        wp.expect_eq(v2, vec2_equiv(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3_equiv(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4_equiv(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5_equiv(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # launch with instances of the equivalent types holding 1..n
    equiv_types = (vec2_equiv, vec3_equiv, vec4_equiv, vec5_equiv)
    args = [vtype(*range(1, n + 1)) for n, vtype in zip(range(2, 6), equiv_types)]

    wp.launch(kernel, dim=1, inputs=args, device=device)
1827
-
1828
-
1829
def test_conversions(test, device, dtype, register_kernels=False):
    """Check ``wp.vec3`` values built from tuples, lists and NumPy arrays.

    The kernel uses ``wp.expect_eq`` to compare each of three candidate vectors
    against ``wp.vec3(1, 2, 3)``. It is launched twice: first with
    ``wp.vec3`` objects constructed explicitly on the host, then with the raw
    containers handed straight to ``wp.launch``.
    """

    def check_vectors_equal(
        v0: wp.vec3,
        v1: wp.vec3,
        v2: wp.vec3,
        v3: wp.vec3,
    ):
        wp.expect_eq(v1, v0)
        wp.expect_eq(v2, v0)
        wp.expect_eq(v3, v0)

    kernel = getkernel(check_vectors_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.vec3(1, 2, 3)

    # explicit conversions: construct wp.vec3 from different containers
    explicit = [
        wp.vec3((1, 2, 3)),
        wp.vec3([1, 2, 3]),
        wp.vec3(np.array([1, 2, 3], dtype=dtype)),
    ]
    wp.launch(kernel, dim=1, inputs=[reference, *explicit], device=device)

    # implicit conversions: pass the raw containers and let wp.launch() convert them
    implicit = [
        (1, 2, 3),
        [1, 2, 3],
        np.array([1, 2, 3], dtype=dtype),
    ]
    wp.launch(kernel, dim=1, inputs=[reference, *implicit], device=device)
1860
-
1861
-
1862
def test_constants(test, device, dtype, register_kernels=False):
    """Check that ``wp.constant`` vectors of length 2 through 5 can be read inside a kernel.

    The constants are module-level warp constants captured by the kernel and
    compared with ``wp.expect_eq`` against vectors constructed in the kernel.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    cv2 = wp.constant(vec2(1, 2))
    cv3 = wp.constant(vec3(1, 2, 3))
    cv4 = wp.constant(vec4(1, 2, 3, 4))
    cv5 = wp.constant(vec5(1, 2, 3, 4, 5))

    def check_vector_constants():
        wp.expect_eq(cv2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(cv3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(cv4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(cv5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_vector_constants, suffix=dtype.__name__)

    if register_kernels:
        return

    # Launch on the device under test. Without an explicit device, wp.launch
    # uses warp's default device, so a per-device test run would not actually
    # exercise the requested device (every other test here passes device=device).
    wp.launch(kernel, dim=1, inputs=[], device=device)
1886
-
1887
-
1888
def test_minmax(test, device, dtype, register_kernels=False):
    """Test component-wise ``wp.min``/``wp.max`` on length-2..5 vectors.

    A single-threaded kernel walks 10 rows of 14 values from ``a`` and ``b``.
    It packs each row into vec2/vec3/vec4/vec5 operands (2 + 3 + 4 + 5 = 14
    components) and writes ``2 * wp.min(...)`` to ``mins`` and
    ``2 * wp.max(...)`` to ``maxs``. Forward results are compared against
    NumPy. For float dtypes, every output element is then back-propagated on
    its own and the resulting gradients on ``a``/``b`` are checked.

    When ``register_kernels`` is true the function only creates/caches the
    kernels and returns without launching anything.

    NOTE: the registration of this test is commented out in the registration
    loop because its kernel compiles very slowly.
    """
    rng = np.random.default_rng(123)

    # \TODO: not quite sure why, but the numbers are off for 16 bit float
    # on the cpu (but not cuda). This is probably just the sketchy float16
    # arithmetic I implemented to get all this stuff working, so
    # hopefully that can be fixed when we do that correctly.
    # Dtypes other than float16 fall back to tol=0.
    tol = {
        np.float16: 1.0e-2,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # \TODO: Also not quite sure why: this kernel compiles incredibly
    # slowly though...
    def check_vec_min_max(
        a: wp.array(dtype=wptype, ndim=2),
        b: wp.array(dtype=wptype, ndim=2),
        mins: wp.array(dtype=wptype, ndim=2),
        maxs: wp.array(dtype=wptype, ndim=2),
    ):
        # Columns 0-1 feed the vec2 operands, 2-4 the vec3, 5-8 the vec4
        # and 9-13 the vec5; results are written back to the same columns.
        for i in range(10):
            # multiplying by 2 so we've got something to backpropagate:
            a2read = vec2(a[i, 0], a[i, 1])
            b2read = vec2(b[i, 0], b[i, 1])
            c2 = wptype(2) * wp.min(a2read, b2read)
            d2 = wptype(2) * wp.max(a2read, b2read)

            a3read = vec3(a[i, 2], a[i, 3], a[i, 4])
            b3read = vec3(b[i, 2], b[i, 3], b[i, 4])
            c3 = wptype(2) * wp.min(a3read, b3read)
            d3 = wptype(2) * wp.max(a3read, b3read)

            a4read = vec4(a[i, 5], a[i, 6], a[i, 7], a[i, 8])
            b4read = vec4(b[i, 5], b[i, 6], b[i, 7], b[i, 8])
            c4 = wptype(2) * wp.min(a4read, b4read)
            d4 = wptype(2) * wp.max(a4read, b4read)

            a5read = vec5(a[i, 9], a[i, 10], a[i, 11], a[i, 12], a[i, 13])
            b5read = vec5(b[i, 9], b[i, 10], b[i, 11], b[i, 12], b[i, 13])
            c5 = wptype(2) * wp.min(a5read, b5read)
            d5 = wptype(2) * wp.max(a5read, b5read)

            mins[i, 0] = c2[0]
            mins[i, 1] = c2[1]

            mins[i, 2] = c3[0]
            mins[i, 3] = c3[1]
            mins[i, 4] = c3[2]

            mins[i, 5] = c4[0]
            mins[i, 6] = c4[1]
            mins[i, 7] = c4[2]
            mins[i, 8] = c4[3]

            mins[i, 9] = c5[0]
            mins[i, 10] = c5[1]
            mins[i, 11] = c5[2]
            mins[i, 12] = c5[3]
            mins[i, 13] = c5[4]

            maxs[i, 0] = d2[0]
            maxs[i, 1] = d2[1]

            maxs[i, 2] = d3[0]
            maxs[i, 3] = d3[1]
            maxs[i, 4] = d3[2]

            maxs[i, 5] = d4[0]
            maxs[i, 6] = d4[1]
            maxs[i, 7] = d4[2]
            maxs[i, 8] = d4[3]

            maxs[i, 9] = d5[0]
            maxs[i, 10] = d5[1]
            maxs[i, 11] = d5[2]
            maxs[i, 12] = d5[3]
            maxs[i, 13] = d5[4]

    kernel = getkernel(check_vec_min_max, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel2(wptype)

    # Registration pass: the kernels above are now cached, nothing to launch.
    if register_kernels:
        return

    a = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)
    b = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)

    mins = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)
    maxs = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)

    # Forward pass only; this tape is not used for a backward pass.
    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)

    assert_np_equal(mins.numpy(), 2 * np.minimum(a.numpy(), b.numpy()), tol=tol)
    assert_np_equal(maxs.numpy(), 2 * np.maximum(a.numpy(), b.numpy()), tol=tol)

    # One-element loss array filled by output_select_kernel.
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    if dtype in np_float_types:
        for i in range(10):
            for j in range(14):
                # Gradient of mins[i, j]: 2 flows only to the smaller operand.
                # NOTE(review): on an exact tie (a == b) both expectations are 0.
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
                    wp.launch(output_select_kernel, dim=1, inputs=[mins, i, j], outputs=[out], device=device)

                tape.backward(loss=out)
                expected = np.zeros_like(a.numpy())
                expected[i, j] = 2 if (a.numpy()[i, j] < b.numpy()[i, j]) else 0
                assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
                expected[i, j] = 2 if (b.numpy()[i, j] < a.numpy()[i, j]) else 0
                assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
                tape.zero()

                # Gradient of maxs[i, j]: 2 flows only to the larger operand.
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
                    wp.launch(output_select_kernel, dim=1, inputs=[maxs, i, j], outputs=[out], device=device)

                tape.backward(loss=out)
                expected = np.zeros_like(a.numpy())
                expected[i, j] = 2 if (a.numpy()[i, j] > b.numpy()[i, j]) else 0
                assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
                expected[i, j] = 2 if (b.numpy()[i, j] > a.numpy()[i, j]) else 0
                assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
                tape.zero()
2019
-
2020
-
2021
devices = get_test_devices()


class TestVecScalarOps(unittest.TestCase):
    """Test case populated below with one generated test per (test function, dtype)."""


for dtype in np_scalar_types:
    # Tests that take no ``register_kernels`` argument: (name prefix, function, devices).
    for prefix, test_func, test_devices in (
        ("test_arrays", test_arrays, devices),
        ("test_components", test_components, None),
        ("test_py_arithmetic_ops", test_py_arithmetic_ops, None),
    ):
        add_function_test(
            TestVecScalarOps, f"{prefix}_{dtype.__name__}", test_func, devices=test_devices, dtype=dtype
        )

    # Tests registered through add_function_test_register_kernel, in the original order.
    for prefix, test_func in (
        ("test_constructors", test_constructors),
        ("test_anon_type_instance", test_anon_type_instance),
        ("test_indexing", test_indexing),
        ("test_equality", test_equality),
        ("test_scalar_multiplication", test_scalar_multiplication),
        ("test_scalar_multiplication_rightmul", test_scalar_multiplication_rightmul),
        ("test_cw_multiplication", test_cw_multiplication),
        ("test_scalar_division", test_scalar_division),
        ("test_cw_division", test_cw_division),
        ("test_addition", test_addition),
        ("test_dotproduct", test_dotproduct),
        ("test_equivalent_types", test_equivalent_types),
        ("test_conversions", test_conversions),
        ("test_constants", test_constants),
    ):
        add_function_test_register_kernel(
            TestVecScalarOps, f"{prefix}_{dtype.__name__}", test_func, devices=devices, dtype=dtype
        )

    # the kernels in this test compile incredibly slowly...
    # add_function_test_register_kernel(TestVecScalarOps, f"test_minmax_{dtype.__name__}", test_minmax, devices=devices, dtype=dtype)


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
# NumPy scalar types that the per-dtype tests are generated for.
#
# ``np.byte`` and ``np.ubyte`` are intentionally not listed. They are aliases
# of ``np.int8`` and ``np.uint8`` (``np.byte is np.int8``), so listing them
# registered every test a second time under the same name (the test name is
# built from ``dtype.__name__``) and repeated the per-dtype kernel
# registration without adding coverage.
np_signed_int_types = [
    np.int8,
    np.int16,
    np.int32,
    np.int64,
]

np_unsigned_int_types = [
    np.uint8,
    np.uint16,
    np.uint32,
    np.uint64,
]

np_int_types = np_signed_int_types + np_unsigned_int_types

np_float_types = [np.float16, np.float32, np.float64]

np_scalar_types = np_int_types + np_float_types
36
+
37
+
38
def randvals(rng, shape, dtype):
    """Return random test values with the given ``shape`` and ``dtype``.

    Dtypes listed in ``np_float_types`` get standard-normal samples cast to
    ``dtype``. Any other dtype is drawn with ``rng.integers``: from [1, 3) for
    the 8-bit integer types and from [1, 5) otherwise.
    """
    if dtype in np_float_types:
        return rng.standard_normal(size=shape).astype(dtype)

    # 8-bit integers use a smaller exclusive upper bound than wider types.
    eight_bit_types = (np.int8, np.uint8, np.byte, np.ubyte)
    upper = 3 if dtype in eight_bit_types else 5
    return rng.integers(1, high=upper, size=shape, dtype=dtype)
44
+
45
+
46
# Kernel objects keyed by "<function name>_<suffix>", so each
# (function, suffix) combination is only wrapped in a wp.Kernel once.
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return the cached ``wp.Kernel`` for ``func``/``suffix``, creating it on first use.

    The cache key, which is also passed as the kernel ``key``, is
    ``func.__name__`` and ``suffix`` joined by an underscore.
    """
    key = "_".join((func.__name__, suffix))
    if key in kernel_cache:
        return kernel_cache[key]

    kernel = wp.Kernel(func=func, key=key)
    kernel_cache[key] = kernel
    return kernel
54
+
55
+
56
def get_select_kernel(dtype):
    """Return a (cached) kernel that copies ``input[index]`` into ``out[0]``.

    The tests use it to reduce one element of a 1D output array to a
    one-element array that serves as the loss for ``Tape.backward``.
    """

    # The function name is part of the getkernel cache key; keep it stable.
    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        selected = input[index]
        out[0] = selected

    kernel_suffix = dtype.__name__
    return getkernel(output_select_kernel_fn, suffix=kernel_suffix)
65
+
66
+
67
def get_select_kernel2(dtype):
    """Return a (cached) kernel that copies ``input[index0, index1]`` into ``out[0]``.

    2D counterpart of ``get_select_kernel``. It reduces one element of a 2D
    array to a one-element array that can serve as the loss for
    ``Tape.backward``.
    """

    # The function name is part of the getkernel cache key; keep it stable.
    def output_select_kernel2_fn(
        input: wp.array(dtype=dtype, ndim=2),
        index0: int,
        index1: int,
        out: wp.array(dtype=dtype),
    ):
        selected = input[index0, index1]
        out[0] = selected

    kernel_suffix = dtype.__name__
    return getkernel(output_select_kernel2_fn, suffix=kernel_suffix)
77
+
78
+
79
def test_arrays(test, device, dtype):
    """Round-trip random host data through ``wp.array`` of vector types.

    Arrays of ten length-2..5 vectors are built from NumPy data and read back
    with ``.numpy()``; the result must match the source data. The length-2..4
    vector types are then declared a second time and the round trip is
    repeated with the same data (presumably to check that re-declared vector
    types behave the same -- TODO confirm intent).
    """
    rng = np.random.default_rng(123)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    lengths = (2, 3, 4, 5)
    vec_types = {n: wp.types.vector(length=n, dtype=wptype) for n in lengths}

    # Draw in the order 2, 3, 4, 5 so the RNG stream matches the per-length draws.
    host_data = {n: randvals(rng, (10, n), dtype) for n in lengths}

    device_arrays = {
        n: wp.array(host_data[n], dtype=vec_types[n], requires_grad=True, device=device) for n in lengths
    }
    for n in lengths:
        assert_np_equal(device_arrays[n].numpy(), host_data[n], tol=1.0e-6)

    # Repeat the round trip with freshly declared length-2..4 vector types.
    short_lengths = lengths[:-1]
    redeclared = {n: wp.types.vector(length=n, dtype=wptype) for n in short_lengths}
    rebuilt = {
        n: wp.array(host_data[n], dtype=redeclared[n], requires_grad=True, device=device) for n in short_lengths
    }
    for n in short_lengths:
        assert_np_equal(rebuilt[n].numpy(), host_data[n], tol=1.0e-6)
114
+
115
+
116
def test_components(test, device, dtype):
    """Check Python-scope component access on a length-3 vector of ``dtype``.

    Exercises ``__getitem__`` and ``__setitem__`` with integer indices and
    with slices. This is especially important for float16, which requires
    special handling internally.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def expect_items(seq, expected):
        # Compare seq[0], seq[1], ... against the expected values, in order.
        for position, value in enumerate(expected):
            test.assertEqual(seq[position], value)

    v = vec3(1, 2, 3)

    # __getitem__ with individual indices
    expect_items(v, (1, 2, 3))

    # __getitem__ with slices: v[:], v[1:], v[:2], v[::2]
    slice_reads = (
        (slice(None), (1, 2, 3)),
        (slice(1, None), (2, 3)),
        (slice(None, 2), (1, 2)),
        (slice(None, None, 2), (1, 3)),
    )
    for key, expected in slice_reads:
        expect_items(v[key], expected)

    # __setitem__ with individual indices
    for position, value in enumerate((4, 5, 6)):
        v[position] = value
    expect_items(v, (4, 5, 6))

    # __setitem__ with slices; each step builds on the previous state of v
    slice_writes = (
        (slice(None), [7, 8, 9], (7, 8, 9)),
        (slice(1, None), [10, 11], (7, 10, 11)),
        (slice(None, 2), [12, 13], (12, 13, 11)),
        (slice(None, None, 2), [14, 15], (14, 13, 15)),
    )
    for key, values, expected in slice_writes:
        v[key] = values
        expect_items(v, expected)
176
+
177
+
178
def test_py_arithmetic_ops(test, device, dtype):
    """Check Python-scope arithmetic operators on a length-3 vector of ``dtype``.

    Covers unary ``+``/``-``, vector ``+``/``-``, and scalar ``*``/``/`` on
    both sides. For integer types the expected values are passed through the
    matching ctypes type, so wrap-around (e.g. negative values in unsigned
    types) is modelled the same way.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def make_vec(*args):
        if wptype not in wp.types.int_types:
            return args
        # Cast to the correct integer type to simulate wrapping.
        return tuple(wptype._type_(x).value for x in args)

    vec_cls = wp.vec(3, wptype)

    signed = vec_cls(1, -2, 3)
    for operation, expected in (
        (lambda: +signed, (1, -2, 3)),
        (lambda: -signed, (-1, 2, -3)),
        (lambda: signed + vec_cls(5, 5, 5), (6, 3, 8)),
        (lambda: signed - vec_cls(5, 5, 5), (-4, -7, -2)),
    ):
        test.assertSequenceEqual(operation(), make_vec(*expected))

    evens = vec_cls(2, 4, 6)
    for operation, expected in (
        (lambda: evens * wptype(2), (4, 8, 12)),
        (lambda: wptype(2) * evens, (4, 8, 12)),
        (lambda: evens / wptype(2), (1, 2, 3)),
        (lambda: wptype(24) / evens, (12, 6, 4)),
    ):
        test.assertSequenceEqual(operation(), make_vec(*expected))
201
+
202
+
203
def test_constructors(test, device, dtype, register_kernels=False):
    """Test scalar-fill and per-component constructors of length-2..5 vectors.

    Two kernels are exercised:

    * ``check_scalar_constructor`` builds every vector from the single value
      ``input[0]`` (e.g. ``vec3(x)``), so every component should equal it.
    * ``check_vector_constructors`` builds the vectors from consecutive entries
      of a 14-element input (2 + 3 + 4 + 5 components).

    Both kernels also write ``2 * component`` into one scalar array per
    component. For float dtypes, the gradient of each of those outputs with
    respect to the contributing input entry is therefore expected to be 2.

    When ``register_kernels`` is true the function only creates/caches the
    kernels and returns without launching anything.
    """
    rng = np.random.default_rng(123)

    # Per-dtype comparison tolerance; integer dtypes fall back to 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # vNK receives 2 * (component K of the length-N vector).
    def check_scalar_constructor(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # Every vector is filled from the same scalar input[0].
        v2result = vec2(input[0])
        v3result = vec3(input[0])
        v4result = vec4(input[0])
        v5result = vec5(input[0])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply outputs by 2 so we've got something to backpropagate
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    def check_vector_constructors(
        input: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # Components come from consecutive input entries: 0-1, 2-4, 5-8, 9-13.
        v2result = vec2(input[0], input[1])
        v3result = vec3(input[2], input[3], input[4])
        v4result = vec4(input[5], input[6], input[7], input[8])
        v5result = vec5(input[9], input[10], input[11], input[12], input[13])

        v2[0] = v2result
        v3[0] = v3result
        v4[0] = v4result
        v5[0] = v5result

        # multiply the output by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    vec_kernel = getkernel(check_vector_constructors, suffix=dtype.__name__)
    kernel = getkernel(check_scalar_constructor, suffix=dtype.__name__)

    # Registration pass: the kernels above are now cached, nothing to launch.
    if register_kernels:
        return

    # --- scalar-fill constructor ---
    input = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.zeros(1, dtype=vec2, device=device)
    v3 = wp.zeros(1, dtype=vec3, device=device)
    v4 = wp.zeros(1, dtype=vec4, device=device)
    v5 = wp.zeros(1, dtype=vec5, device=device)
    v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[input],
            outputs=[v2, v3, v4, v5, v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
            device=device,
        )

    if dtype in np_float_types:
        # Every output is 2 * input[0], so each d(output)/d(input[0]) is 2;
        # tape.zero() resets gradients before the next output is checked.
        for l in [v20, v21]:
            tape.backward(loss=l)
            test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
            tape.zero()

        for l in [v30, v31, v32]:
            tape.backward(loss=l)
            test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
            tape.zero()

        for l in [v40, v41, v42, v43]:
            tape.backward(loss=l)
            test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
            tape.zero()

        for l in [v50, v51, v52, v53, v54]:
            tape.backward(loss=l)
            test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
            tape.zero()

    val = input.numpy()[0]
    assert_np_equal(v2.numpy()[0], np.array([val, val]), tol=1.0e-6)
    assert_np_equal(v3.numpy()[0], np.array([val, val, val]), tol=1.0e-6)
    assert_np_equal(v4.numpy()[0], np.array([val, val, val, val]), tol=1.0e-6)
    assert_np_equal(v5.numpy()[0], np.array([val, val, val, val, val]), tol=1.0e-6)

    assert_np_equal(v20.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v21.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v30.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v31.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v32.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v40.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v41.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v42.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v43.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v50.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v51.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v52.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v53.numpy()[0], 2 * val, tol=1.0e-6)
    assert_np_equal(v54.numpy()[0], 2 * val, tol=1.0e-6)

    # --- per-component constructor (reuses the output arrays above) ---
    input = wp.array(randvals(rng, [14], dtype), requires_grad=True, device=device)
    tape = wp.Tape()
    with tape:
        wp.launch(
            vec_kernel,
            dim=1,
            inputs=[input],
            outputs=[v2, v3, v4, v5, v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
            device=device,
        )

    if dtype in np_float_types:
        # Output i depends only on input[i], so its gradient is 2 at index i and 0 elsewhere.
        for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
            tape.backward(loss=l)
            grad = tape.gradients[input].numpy()
            expected_grad = np.zeros_like(grad)
            expected_grad[i] = 2
            assert_np_equal(grad, expected_grad, tol=tol)
            tape.zero()

    assert_np_equal(v2.numpy()[0, 0], input.numpy()[0], tol=tol)
    assert_np_equal(v2.numpy()[0, 1], input.numpy()[1], tol=tol)
    assert_np_equal(v3.numpy()[0, 0], input.numpy()[2], tol=tol)
    assert_np_equal(v3.numpy()[0, 1], input.numpy()[3], tol=tol)
    assert_np_equal(v3.numpy()[0, 2], input.numpy()[4], tol=tol)
    assert_np_equal(v4.numpy()[0, 0], input.numpy()[5], tol=tol)
    assert_np_equal(v4.numpy()[0, 1], input.numpy()[6], tol=tol)
    assert_np_equal(v4.numpy()[0, 2], input.numpy()[7], tol=tol)
    assert_np_equal(v4.numpy()[0, 3], input.numpy()[8], tol=tol)
    assert_np_equal(v5.numpy()[0, 0], input.numpy()[9], tol=tol)
    assert_np_equal(v5.numpy()[0, 1], input.numpy()[10], tol=tol)
    assert_np_equal(v5.numpy()[0, 2], input.numpy()[11], tol=tol)
    assert_np_equal(v5.numpy()[0, 3], input.numpy()[12], tol=tol)
    assert_np_equal(v5.numpy()[0, 4], input.numpy()[13], tol=tol)

    assert_np_equal(v20.numpy()[0], 2 * input.numpy()[0], tol=tol)
    assert_np_equal(v21.numpy()[0], 2 * input.numpy()[1], tol=tol)
    assert_np_equal(v30.numpy()[0], 2 * input.numpy()[2], tol=tol)
    assert_np_equal(v31.numpy()[0], 2 * input.numpy()[3], tol=tol)
    assert_np_equal(v32.numpy()[0], 2 * input.numpy()[4], tol=tol)
    assert_np_equal(v40.numpy()[0], 2 * input.numpy()[5], tol=tol)
    assert_np_equal(v41.numpy()[0], 2 * input.numpy()[6], tol=tol)
    assert_np_equal(v42.numpy()[0], 2 * input.numpy()[7], tol=tol)
    assert_np_equal(v43.numpy()[0], 2 * input.numpy()[8], tol=tol)
    assert_np_equal(v50.numpy()[0], 2 * input.numpy()[9], tol=tol)
    assert_np_equal(v51.numpy()[0], 2 * input.numpy()[10], tol=tol)
    assert_np_equal(v52.numpy()[0], 2 * input.numpy()[11], tol=tol)
    assert_np_equal(v53.numpy()[0], 2 * input.numpy()[12], tol=tol)
    assert_np_equal(v54.numpy()[0], 2 * input.numpy()[13], tol=tol)
445
+
446
+
447
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Test the anonymous ``wp.vector(...)`` constructor inside kernels.

    ``check_scalar_init`` builds length-2..5 vectors from one value each via
    ``wp.vector(x, length=n)``. ``check_component_init`` builds them from
    individual components via ``wp.vector(a, b, ...)``. Both write
    ``2 * component`` for all 2 + 3 + 4 + 5 = 14 components into ``output``.
    Forward values are checked against NumPy. For float dtypes, each output
    element's gradient is then checked by selecting that element into a
    one-element loss array and back-propagating.

    When ``register_kernels`` is true the function only creates/caches the
    kernels and returns without launching anything.
    """
    rng = np.random.default_rng(123)

    # Per-dtype comparison tolerance; integer dtypes fall back to 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        # input[k] fills every component of the k-th vector (lengths 2..5).
        v2result = wp.vector(input[0], length=2)
        v3result = wp.vector(input[1], length=3)
        v4result = wp.vector(input[2], length=4)
        v5result = wp.vector(input[3], length=5)

        # Flatten 2 * components into output[0:14] in vec2, vec3, vec4, vec5 order.
        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        # Vector lengths are inferred from the number of components.
        v2result = wp.vector(input[0], input[1])
        v3result = wp.vector(input[2], input[3], input[4])
        v4result = wp.vector(input[5], input[6], input[7], input[8])
        v5result = wp.vector(input[9], input[10], input[11], input[12], input[13])

        # Flatten 2 * components into output[0:14]; output[i] == 2 * input[i].
        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    # Registration pass: the kernels above are now cached, nothing to launch.
    if register_kernels:
        return

    # --- scalar-fill form ---
    input = wp.array(randvals(rng, [4], dtype), requires_grad=True, device=device)
    output = wp.zeros(2 + 3 + 4 + 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)

    assert_np_equal(output.numpy()[:2], 2 * np.array([input.numpy()[0]] * 2), tol=1.0e-6)
    assert_np_equal(output.numpy()[2:5], 2 * np.array([input.numpy()[1]] * 3), tol=1.0e-6)
    assert_np_equal(output.numpy()[5:9], 2 * np.array([input.numpy()[2]] * 4), tol=1.0e-6)
    assert_np_equal(output.numpy()[9:], 2 * np.array([input.numpy()[3]] * 5), tol=1.0e-6)

    if dtype in np_float_types:
        # One-element loss array filled by output_select_kernel.
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(len(output)):
            tape = wp.Tape()
            with tape:
                wp.launch(scalar_kernel, dim=1, inputs=[input], outputs=[output], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)

            tape.backward(loss=out)
            # output[i] comes from the input entry of the vector it belongs to:
            # indices 0-1 -> input[0], 2-4 -> input[1], 5-8 -> input[2], 9-13 -> input[3].
            expected = np.zeros_like(input.numpy())
            if i < 2:
                expected[0] = 2
            elif i < 5:
                expected[1] = 2
            elif i < 9:
                expected[2] = 2
            else:
                expected[3] = 2

            assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # --- per-component form ---
    input = wp.array(randvals(rng, [2 + 3 + 4 + 5], dtype), requires_grad=True, device=device)
    output = wp.zeros(2 + 3 + 4 + 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)

    assert_np_equal(output.numpy(), 2 * input.numpy(), tol=1.0e-6)

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(len(output)):
            tape = wp.Tape()
            with tape:
                wp.launch(component_kernel, dim=1, inputs=[input], outputs=[output], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)

            tape.backward(loss=out)
            # output[i] == 2 * input[i], so only input[i] receives a gradient (of 2).
            expected = np.zeros_like(input.numpy())
            expected[i] = 2

            assert_np_equal(tape.gradients[input].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()
568
+
569
+
570
def test_indexing(test, device, dtype, register_kernels=False):
    """Test reading vector components with ``arr[0][k]`` inside a kernel.

    The kernel reads every component of one vec2/vec3/vec4/vec5 array element
    and writes ``2 * component`` into one scalar array per component. Forward
    values are checked against NumPy. For float dtypes, each output is
    back-propagated and the gradients on the four vector inputs (concatenated
    in vec2, vec3, vec4, vec5 order) must be 2 at the matching component and
    0 elsewhere.

    When ``register_kernels`` is true the function only creates/caches the
    kernel and returns without launching anything.
    """
    rng = np.random.default_rng(123)

    # Per-dtype comparison tolerance; integer dtypes fall back to 0.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # vNK receives 2 * (component K of vN[0]).
    def check_indexing(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2[0][0]
        v21[0] = wptype(2) * v2[0][1]

        v30[0] = wptype(2) * v3[0][0]
        v31[0] = wptype(2) * v3[0][1]
        v32[0] = wptype(2) * v3[0][2]

        v40[0] = wptype(2) * v4[0][0]
        v41[0] = wptype(2) * v4[0][1]
        v42[0] = wptype(2) * v4[0][2]
        v43[0] = wptype(2) * v4[0][3]

        v50[0] = wptype(2) * v5[0][0]
        v51[0] = wptype(2) * v5[0][1]
        v52[0] = wptype(2) * v5[0][2]
        v53[0] = wptype(2) * v5[0][3]
        v54[0] = wptype(2) * v5[0][4]

    kernel = getkernel(check_indexing, suffix=dtype.__name__)

    # Registration pass: the kernel above is now cached, nothing to launch.
    if register_kernels:
        return

    # One random vector per length; the NumPy data is (1, N) per array.
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[v2, v3, v4, v5],
            outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
            device=device,
        )

    if dtype in np_float_types:
        # The outputs are enumerated in the same order as the concatenated input
        # components, so output i's gradient is 2 at position i only.
        for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
            tape.backward(loss=l)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = 2
            assert_np_equal(allgrads, expected_grads, tol=tol)
            tape.zero()

    assert_np_equal(v20.numpy()[0], 2.0 * v2.numpy()[0, 0], tol=tol)
    assert_np_equal(v21.numpy()[0], 2.0 * v2.numpy()[0, 1], tol=tol)
    assert_np_equal(v30.numpy()[0], 2.0 * v3.numpy()[0, 0], tol=tol)
    assert_np_equal(v31.numpy()[0], 2.0 * v3.numpy()[0, 1], tol=tol)
    assert_np_equal(v32.numpy()[0], 2.0 * v3.numpy()[0, 2], tol=tol)
    assert_np_equal(v40.numpy()[0], 2.0 * v4.numpy()[0, 0], tol=tol)
    assert_np_equal(v41.numpy()[0], 2.0 * v4.numpy()[0, 1], tol=tol)
    assert_np_equal(v42.numpy()[0], 2.0 * v4.numpy()[0, 2], tol=tol)
    assert_np_equal(v43.numpy()[0], 2.0 * v4.numpy()[0, 3], tol=tol)
    assert_np_equal(v50.numpy()[0], 2.0 * v5.numpy()[0, 0], tol=tol)
    assert_np_equal(v51.numpy()[0], 2.0 * v5.numpy()[0, 1], tol=tol)
    assert_np_equal(v52.numpy()[0], 2.0 * v5.numpy()[0, 2], tol=tol)
    assert_np_equal(v53.numpy()[0], 2.0 * v5.numpy()[0, 3], tol=tol)
    assert_np_equal(v54.numpy()[0], 2.0 * v5.numpy()[0, 4], tol=tol)
681
+
682
+
683
def test_equality(test, device, dtype, register_kernels=False):
    """Exercise ``wp.expect_eq`` / ``wp.expect_neq`` on vectors of length 2 to 5.

    For every vector length the kernel compares a base vector with itself
    (expected equal) and with variants that differ from the base in exactly one
    component (expected not equal).

    Args:
        test: Test harness object (not used by this function).
        device: Device on which the arrays are allocated and the kernel launched.
        dtype: NumPy scalar type, mapped to the Warp component type.
        register_kernels: When true, only build the kernel and return without launching.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_equality(
        v20: wp.array(dtype=vec2),
        v21: wp.array(dtype=vec2),
        v22: wp.array(dtype=vec2),
        v30: wp.array(dtype=vec3),
        v31: wp.array(dtype=vec3),
        v32: wp.array(dtype=vec3),
        v33: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v41: wp.array(dtype=vec4),
        v42: wp.array(dtype=vec4),
        v43: wp.array(dtype=vec4),
        v44: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
        v51: wp.array(dtype=vec5),
        v52: wp.array(dtype=vec5),
        v53: wp.array(dtype=vec5),
        v54: wp.array(dtype=vec5),
        v55: wp.array(dtype=vec5),
    ):
        wp.expect_eq(v20[0], v20[0])
        wp.expect_neq(v21[0], v20[0])
        wp.expect_neq(v22[0], v20[0])

        wp.expect_eq(v30[0], v30[0])
        wp.expect_neq(v31[0], v30[0])
        wp.expect_neq(v32[0], v30[0])
        wp.expect_neq(v33[0], v30[0])

        wp.expect_eq(v40[0], v40[0])
        wp.expect_neq(v41[0], v40[0])
        wp.expect_neq(v42[0], v40[0])
        wp.expect_neq(v43[0], v40[0])
        wp.expect_neq(v44[0], v40[0])

        wp.expect_eq(v50[0], v50[0])
        wp.expect_neq(v51[0], v50[0])
        wp.expect_neq(v52[0], v50[0])
        wp.expect_neq(v53[0], v50[0])
        wp.expect_neq(v54[0], v50[0])
        wp.expect_neq(v55[0], v50[0])

    kernel = getkernel(check_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    # For each vector type: the base vector first, then its one-component variants.
    # The flattened order must match the kernel's parameter list (v20, v21, ..., v55).
    cases = (
        (vec2, ([1.0, 2.0], [1.0, 3.0], [3.0, 2.0])),
        (
            vec3,
            (
                [1.0, 2.0, 3.0],
                [-1.0, 2.0, 3.0],
                [1.0, -2.0, 3.0],
                [1.0, 2.0, -3.0],
            ),
        ),
        (
            vec4,
            (
                [1.0, 2.0, 3.0, 4.0],
                [-1.0, 2.0, 3.0, 4.0],
                [1.0, -2.0, 3.0, 4.0],
                [1.0, 2.0, -3.0, 4.0],
                [1.0, 2.0, 3.0, -4.0],
            ),
        ),
        (
            vec5,
            (
                [1.0, 2.0, 3.0, 4.0, 5.0],
                [-1.0, 2.0, 3.0, 4.0, 5.0],
                [1.0, -2.0, 3.0, 4.0, 5.0],
                [1.0, 2.0, -3.0, 4.0, 5.0],
                [1.0, 2.0, 3.0, -4.0, 5.0],
                [1.0, 2.0, 3.0, 4.0, -5.0],
            ),
        ),
    )

    kernel_inputs = []
    for vtype, variants in cases:
        for components in variants:
            kernel_inputs.append(wp.array(components, dtype=vtype, requires_grad=True, device=device))

    wp.launch(kernel, dim=1, inputs=kernel_inputs, outputs=[], device=device)
784
+
785
+
786
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check ``scalar * vector`` for vectors of length 2 to 5, forward and backward.

    The kernel writes ``2 * (s * v)[j]`` for every component ``j`` of each input
    vector into its own scalar output array. Forward values are compared against
    NumPy. For float dtypes, the gradient of every output with respect to ``s``
    and to every input vector component is checked, including the length-5 vector.

    Args:
        test: Test harness object (not used by this function).
        device: Device on which the arrays are allocated and the kernel launched.
        dtype: NumPy scalar type, mapped to the Warp component type.
        register_kernels: When true, only build the kernel and return without launching.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_mul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = s[0] * v2[0]
        v3result = s[0] * v3[0]
        v4result = s[0] * v4[0]
        v5result = s[0] * v5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vecs = [v2, v3, v4, v5]

    # One scalar output per input component, in the kernel's parameter order
    # (v20, v21, v30, ..., v54): 2 + 3 + 4 + 5 = 14 outputs.
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, *vecs], outputs=outputs, device=device)

    s_val = s.numpy()[0]
    # Flattened input components: incmps[i] is the component that produced outputs[i].
    incmps = np.concatenate([v.numpy()[0] for v in vecs])

    for i, out in enumerate(outputs):
        # The two vec2 outputs are checked with the tighter tolerance.
        out_tol = tol if i < 2 else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * s_val * incmps[i], tol=out_tol)

    if dtype in np_float_types:
        # Backpropagate through every output, including the vec5 components.
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)

            # d(2 * s * v_i)/ds = 2 * v_i
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)

            # d(2 * s * v_i)/dv_j = 2 * s when j == i, else 0
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vecs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s_val * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)

            tape.zero()
916
+
917
+
918
def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=False):
    """Check ``vector * scalar`` for vectors of length 2 to 5, forward and backward.

    The kernel writes ``2 * (v * s)[j]`` for every component ``j`` of each input
    vector into its own scalar output array. Forward values are compared against
    NumPy. For float dtypes, the gradient of every output with respect to ``s``
    and to every input vector component is checked, including the length-5 vector.

    Args:
        test: Test harness object (not used by this function).
        device: Device on which the arrays are allocated and the kernel launched.
        dtype: NumPy scalar type, mapped to the Warp component type.
        register_kernels: When true, only build the kernel and return without launching.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_rightmul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] * s[0]
        v3result = v3[0] * s[0]
        v4result = v4[0] * s[0]
        v5result = v5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_rightmul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    vecs = [v2, v3, v4, v5]

    # One scalar output per input component, in the kernel's parameter order
    # (v20, v21, v30, ..., v54): 2 + 3 + 4 + 5 = 14 outputs.
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, *vecs], outputs=outputs, device=device)

    s_val = s.numpy()[0]
    # Flattened input components: incmps[i] is the component that produced outputs[i].
    incmps = np.concatenate([v.numpy()[0] for v in vecs])

    for i, out in enumerate(outputs):
        # The two vec2 outputs are checked with the tighter tolerance.
        out_tol = tol if i < 2 else 10 * tol
        assert_np_equal(out.numpy()[0], 2 * s_val * incmps[i], tol=out_tol)

    if dtype in np_float_types:
        # Backpropagate through every output, including the vec5 components.
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)

            # d(2 * v_i * s)/ds = 2 * v_i
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)

            # d(2 * v_i * s)/dv_j = 2 * s when j == i, else 0
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vecs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s_val * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)

            tape.zero()
1048
+
1049
+
1050
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_mul`` for vectors of length 2 to 5, forward and backward.

    Each output holds ``2 * s[j] * v[j]`` for one component ``j``. Forward values
    are compared against NumPy. For float dtypes, every output is backpropagated
    and the gradients with respect to both factors are checked.

    Args:
        test: Test harness object (not used by this function).
        device: Device on which the arrays are allocated and the kernel launched.
        dtype: NumPy scalar type, mapped to the Warp component type.
        register_kernels: When true, only build the kernel and return without launching.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_mul(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_mul(s2[0], v2[0])
        v3result = wp.cw_mul(s3[0], v3[0])
        v4result = wp.cw_mul(s4[0], v4[0])
        v5result = wp.cw_mul(s5[0], v5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths_and_types = ((2, vec2), (3, vec3), (4, vec4), (5, vec5))

    def random_vec_array(length, vtype):
        # Single-element array holding one random vector of the given length.
        return wp.array(randvals(rng, (1, length), dtype), dtype=vtype, requires_grad=True, device=device)

    # All scale vectors are drawn before the operand vectors, so the RNG stream
    # is consumed in the same order as s2, s3, s4, s5, v2, v3, v4, v5.
    scale_vecs = [random_vec_array(n, vt) for n, vt in lengths_and_types]
    operand_vecs = [random_vec_array(n, vt) for n, vt in lengths_and_types]

    # 14 scalar outputs in the kernel's order: v20, v21, v30, ..., v54.
    results = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=scale_vecs + operand_vecs,
            outputs=results,
            device=device,
        )

    # Flattened components line up index-for-index with `results`.
    incmps = np.concatenate([arr.numpy()[0] for arr in operand_vecs])
    scmps = np.concatenate([arr.numpy()[0] for arr in scale_vecs])

    for k, res in enumerate(results):
        assert_np_equal(res.numpy()[0], 2 * scmps[k] * incmps[k], tol=10 * tol)

    if dtype not in np_float_types:
        return

    for k, loss in enumerate(results):
        tape.backward(loss=loss)

        # d(2 * s_k * v_k)/ds_j is 2 * v_k on the diagonal and zero elsewhere.
        scale_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in scale_vecs])
        want = np.zeros_like(scale_grads)
        want[k] = incmps[k] * 2
        assert_np_equal(scale_grads, want, tol=10 * tol)

        # d(2 * s_k * v_k)/dv_j is 2 * s_k on the diagonal and zero elsewhere.
        operand_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in operand_vecs])
        want = np.zeros_like(operand_grads)
        want[k] = scmps[k] * 2
        assert_np_equal(operand_grads, want, tol=10 * tol)

        tape.zero()
1193
+
1194
+
1195
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check ``vector / scalar`` for vectors of length 2 to 5, forward and backward.

    Each output holds ``2 * (v / s)[j]`` for one component ``j``. Integer dtypes
    are compared against NumPy floor division, and other dtypes against true
    division. For float dtypes, the gradients with respect to ``s`` and to every
    input component are checked.

    Args:
        test: Test harness object (not used by this function).
        device: Device on which the arrays are allocated and the kernel launched.
        dtype: NumPy scalar type, mapped to the Warp component type.
        register_kernels: When true, only build the kernel and return without launching.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_div(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] / s[0]
        v3result = v3[0] / s[0]
        v4result = v4[0] / s[0]
        v5result = v5[0] / s[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_div, suffix=dtype.__name__)

    if register_kernels:
        return

    # The scalar is drawn first, then v2..v5, matching the original RNG order.
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    numerators = [
        wp.array(randvals(rng, (1, length), dtype), dtype=vtype, requires_grad=True, device=device)
        for length, vtype in ((2, vec2), (3, vec3), (4, vec4), (5, vec5))
    ]

    # 14 scalar outputs in the kernel's order: v20, v21, v30, ..., v54.
    quotients = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(14)]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s] + numerators,
            outputs=quotients,
            device=device,
        )

    divisor = s.numpy()[0]
    # Flattened numerator components, index-aligned with `quotients`.
    incmps = np.concatenate([arr.numpy()[0] for arr in numerators])
    integral = dtype in np_int_types

    for k, q in enumerate(quotients):
        # The two vec2 outputs are checked with the tighter tolerance.
        q_tol = tol if k < 2 else 10 * tol
        if integral:
            reference = 2 * (incmps[k] // divisor)
        else:
            reference = 2 * incmps[k] / divisor
        assert_np_equal(q.numpy()[0], reference, tol=q_tol)

    if dtype not in np_float_types:
        return

    for k, loss in enumerate(quotients):
        tape.backward(loss=loss)

        # d/ds v/s = -v/s^2
        sgrad = tape.gradients[s].numpy()[0]
        assert_np_equal(sgrad, -2 * incmps[k] / (divisor * divisor), tol=10 * tol)

        # d/dv v/s = 1/s, only for the component that produced this output
        numerator_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in numerators])
        want = np.zeros_like(numerator_grads)
        want[k] = 2 / divisor
        assert_np_equal(numerator_grads, want, tol=tol)

        tape.zero()
1349
+
1350
+
1351
+ def test_cw_division(test, device, dtype, register_kernels=False):
1352
+ rng = np.random.default_rng(123)
1353
+
1354
+ tol = {
1355
+ np.float16: 1.0e-2,
1356
+ np.float32: 1.0e-6,
1357
+ np.float64: 1.0e-8,
1358
+ }.get(dtype, 0)
1359
+
1360
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1361
+ vec2 = wp.types.vector(length=2, dtype=wptype)
1362
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1363
+ vec4 = wp.types.vector(length=4, dtype=wptype)
1364
+ vec5 = wp.types.vector(length=5, dtype=wptype)
1365
+
1366
+ def check_cw_div(
1367
+ s2: wp.array(dtype=vec2),
1368
+ s3: wp.array(dtype=vec3),
1369
+ s4: wp.array(dtype=vec4),
1370
+ s5: wp.array(dtype=vec5),
1371
+ v2: wp.array(dtype=vec2),
1372
+ v3: wp.array(dtype=vec3),
1373
+ v4: wp.array(dtype=vec4),
1374
+ v5: wp.array(dtype=vec5),
1375
+ v20: wp.array(dtype=wptype),
1376
+ v21: wp.array(dtype=wptype),
1377
+ v30: wp.array(dtype=wptype),
1378
+ v31: wp.array(dtype=wptype),
1379
+ v32: wp.array(dtype=wptype),
1380
+ v40: wp.array(dtype=wptype),
1381
+ v41: wp.array(dtype=wptype),
1382
+ v42: wp.array(dtype=wptype),
1383
+ v43: wp.array(dtype=wptype),
1384
+ v50: wp.array(dtype=wptype),
1385
+ v51: wp.array(dtype=wptype),
1386
+ v52: wp.array(dtype=wptype),
1387
+ v53: wp.array(dtype=wptype),
1388
+ v54: wp.array(dtype=wptype),
1389
+ ):
1390
+ v2result = wp.cw_div(v2[0], s2[0])
1391
+ v3result = wp.cw_div(v3[0], s3[0])
1392
+ v4result = wp.cw_div(v4[0], s4[0])
1393
+ v5result = wp.cw_div(v5[0], s5[0])
1394
+
1395
+ v20[0] = wptype(2) * v2result[0]
1396
+ v21[0] = wptype(2) * v2result[1]
1397
+
1398
+ v30[0] = wptype(2) * v3result[0]
1399
+ v31[0] = wptype(2) * v3result[1]
1400
+ v32[0] = wptype(2) * v3result[2]
1401
+
1402
+ v40[0] = wptype(2) * v4result[0]
1403
+ v41[0] = wptype(2) * v4result[1]
1404
+ v42[0] = wptype(2) * v4result[2]
1405
+ v43[0] = wptype(2) * v4result[3]
1406
+
1407
+ v50[0] = wptype(2) * v5result[0]
1408
+ v51[0] = wptype(2) * v5result[1]
1409
+ v52[0] = wptype(2) * v5result[2]
1410
+ v53[0] = wptype(2) * v5result[3]
1411
+ v54[0] = wptype(2) * v5result[4]
1412
+
1413
+ kernel = getkernel(check_cw_div, suffix=dtype.__name__)
1414
+
1415
+ if register_kernels:
1416
+ return
1417
+
1418
+ s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1419
+ s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1420
+ s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1421
+ s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1422
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
1423
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
1424
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
1425
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
1426
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1427
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1428
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1429
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1430
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1431
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1432
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1433
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1434
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1435
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1436
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1437
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1438
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1439
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1440
+ tape = wp.Tape()
1441
+ with tape:
1442
+ wp.launch(
1443
+ kernel,
1444
+ dim=1,
1445
+ inputs=[
1446
+ s2,
1447
+ s3,
1448
+ s4,
1449
+ s5,
1450
+ v2,
1451
+ v3,
1452
+ v4,
1453
+ v5,
1454
+ ],
1455
+ outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
1456
+ device=device,
1457
+ )
1458
+
1459
+ if dtype in np_int_types:
1460
+ assert_np_equal(v20.numpy()[0], 2 * (v2.numpy()[0, 0] // s2.numpy()[0, 0]), tol=tol)
1461
+ assert_np_equal(v21.numpy()[0], 2 * (v2.numpy()[0, 1] // s2.numpy()[0, 1]), tol=tol)
1462
+
1463
+ assert_np_equal(v30.numpy()[0], 2 * (v3.numpy()[0, 0] // s3.numpy()[0, 0]), tol=tol)
1464
+ assert_np_equal(v31.numpy()[0], 2 * (v3.numpy()[0, 1] // s3.numpy()[0, 1]), tol=tol)
1465
+ assert_np_equal(v32.numpy()[0], 2 * (v3.numpy()[0, 2] // s3.numpy()[0, 2]), tol=tol)
1466
+
1467
+ assert_np_equal(v40.numpy()[0], 2 * (v4.numpy()[0, 0] // s4.numpy()[0, 0]), tol=tol)
1468
+ assert_np_equal(v41.numpy()[0], 2 * (v4.numpy()[0, 1] // s4.numpy()[0, 1]), tol=tol)
1469
+ assert_np_equal(v42.numpy()[0], 2 * (v4.numpy()[0, 2] // s4.numpy()[0, 2]), tol=tol)
1470
+ assert_np_equal(v43.numpy()[0], 2 * (v4.numpy()[0, 3] // s4.numpy()[0, 3]), tol=tol)
1471
+
1472
+ assert_np_equal(v50.numpy()[0], 2 * (v5.numpy()[0, 0] // s5.numpy()[0, 0]), tol=tol)
1473
+ assert_np_equal(v51.numpy()[0], 2 * (v5.numpy()[0, 1] // s5.numpy()[0, 1]), tol=tol)
1474
+ assert_np_equal(v52.numpy()[0], 2 * (v5.numpy()[0, 2] // s5.numpy()[0, 2]), tol=tol)
1475
+ assert_np_equal(v53.numpy()[0], 2 * (v5.numpy()[0, 3] // s5.numpy()[0, 3]), tol=tol)
1476
+ assert_np_equal(v54.numpy()[0], 2 * (v5.numpy()[0, 4] // s5.numpy()[0, 4]), tol=tol)
1477
+ else:
1478
+ assert_np_equal(v20.numpy()[0], 2 * v2.numpy()[0, 0] / s2.numpy()[0, 0], tol=tol)
1479
+ assert_np_equal(v21.numpy()[0], 2 * v2.numpy()[0, 1] / s2.numpy()[0, 1], tol=tol)
1480
+
1481
+ assert_np_equal(v30.numpy()[0], 2 * v3.numpy()[0, 0] / s3.numpy()[0, 0], tol=tol)
1482
+ assert_np_equal(v31.numpy()[0], 2 * v3.numpy()[0, 1] / s3.numpy()[0, 1], tol=tol)
1483
+ assert_np_equal(v32.numpy()[0], 2 * v3.numpy()[0, 2] / s3.numpy()[0, 2], tol=tol)
1484
+
1485
+ assert_np_equal(v40.numpy()[0], 2 * v4.numpy()[0, 0] / s4.numpy()[0, 0], tol=tol)
1486
+ assert_np_equal(v41.numpy()[0], 2 * v4.numpy()[0, 1] / s4.numpy()[0, 1], tol=tol)
1487
+ assert_np_equal(v42.numpy()[0], 2 * v4.numpy()[0, 2] / s4.numpy()[0, 2], tol=tol)
1488
+ assert_np_equal(v43.numpy()[0], 2 * v4.numpy()[0, 3] / s4.numpy()[0, 3], tol=tol)
1489
+
1490
+ assert_np_equal(v50.numpy()[0], 2 * v5.numpy()[0, 0] / s5.numpy()[0, 0], tol=tol)
1491
+ assert_np_equal(v51.numpy()[0], 2 * v5.numpy()[0, 1] / s5.numpy()[0, 1], tol=tol)
1492
+ assert_np_equal(v52.numpy()[0], 2 * v5.numpy()[0, 2] / s5.numpy()[0, 2], tol=tol)
1493
+ assert_np_equal(v53.numpy()[0], 2 * v5.numpy()[0, 3] / s5.numpy()[0, 3], tol=tol)
1494
+ assert_np_equal(v54.numpy()[0], 2 * v5.numpy()[0, 4] / s5.numpy()[0, 4], tol=tol)
1495
+
1496
+ if dtype in np_float_types:
1497
+ incmps = np.concatenate([v.numpy()[0] for v in [v2, v3, v4, v5]])
1498
+ scmps = np.concatenate([v.numpy()[0] for v in [s2, s3, s4, s5]])
1499
+
1500
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
1501
+ tape.backward(loss=l)
1502
+ sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [s2, s3, s4, s5]])
1503
+ expected_grads = np.zeros_like(sgrads)
1504
+
1505
+ # d/ds v/s = -v/s^2
1506
+ expected_grads[i] = -incmps[i] * 2 / (scmps[i] * scmps[i])
1507
+ assert_np_equal(sgrads, expected_grads, tol=20 * tol)
1508
+
1509
+ allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
1510
+ expected_grads = np.zeros_like(allgrads)
1511
+
1512
+ # d/dv v/s = 1/s
1513
+ expected_grads[i] = 2 / scmps[i]
1514
+ assert_np_equal(allgrads, expected_grads, tol=tol)
1515
+
1516
+ tape.zero()
1517
+
1518
+
1519
def test_addition(test, device, dtype, register_kernels=False):
    """Check component-wise ``vec + vec`` for vector lengths 2..5, including gradients.

    Every kernel output holds ``2 * (v[k] + s[k])`` for a single component ``k``;
    the factor of 2 gives the backward pass a non-trivial gradient to verify.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_add(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] + s2[0]
        v3result = v3[0] + s3[0]
        v4result = v4[0] + s4[0]
        v5result = v5[0] + s5[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_add, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def make_random_vectors():
        # one single-element array per vector length, drawn in length order
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in zip(lengths, vec_types)
        ]

    # the s-operands are drawn before the v-operands
    s_arrays = make_random_vectors()
    v_arrays = make_random_vectors()

    # one scalar output per component: 2 + 3 + 4 + 5 = 14 outputs, in kernel order
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_arrays, *v_arrays], outputs=outputs, device=device)

    # (v operand, s operand, component index) for each output, in the same order as `outputs`
    components = [(v, s, c) for v, s, n in zip(v_arrays, s_arrays, lengths) for c in range(n)]
    last = len(outputs) - 1
    for k, (result, (v, s, c)) in enumerate(zip(outputs, components)):
        # the final component of the length-5 vector is checked with a looser tolerance
        component_tol = 2 * tol if k == last else tol
        assert_np_equal(result.numpy()[0], 2 * (v.numpy()[0, c] + s.numpy()[0, c]), tol=component_tol)

    if dtype in np_float_types:
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)
            sgrads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in s_arrays])
            expected_grads = np.zeros_like(sgrads)

            # d/ds 2*(v + s) == d/dv 2*(v + s) == 2, only at the selected component
            expected_grads[i] = 2
            assert_np_equal(sgrads, expected_grads, tol=10 * tol)

            allgrads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in v_arrays])
            assert_np_equal(allgrads, expected_grads, tol=tol)

            tape.zero()
1658
+
1659
+
1660
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Check ``wp.dot`` for vector lengths 2..5, including gradients.

    Each output holds ``2 * dot(v, s)``, so the expected gradients are
    ``d/ds = 2 * v`` and ``d/dv = 2 * s``.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_dot(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        dot2[0] = wptype(2) * wp.dot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.dot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.dot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.dot(v5[0], s5[0])

    kernel = getkernel(check_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    shapes = ((2, vec2), (3, vec3), (4, vec4), (5, vec5))
    # the s-operands are drawn before the v-operands
    s_arrays = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device) for n, vt in shapes
    ]
    v_arrays = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device) for n, vt in shapes
    ]
    dots = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in shapes]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_arrays, *v_arrays], outputs=dots, device=device)

    for dot, v, s in zip(dots, v_arrays, s_arrays):
        assert_np_equal(dot.numpy()[0], 2.0 * (v.numpy() * s.numpy()).sum(), tol=10 * tol)

    if dtype not in np_float_types:
        return

    last = len(dots) - 1
    for k, (dot, v, s) in enumerate(zip(dots, v_arrays, s_arrays)):
        tape.backward(loss=dot)
        assert_np_equal(tape.gradients[s].numpy()[0], 2.0 * v.numpy()[0], tol=10 * tol)

        # the length-5 v-gradient is checked with a looser tolerance
        vgrad_tol = 10 * tol if k == last else tol
        assert_np_equal(tape.gradients[v].numpy()[0], 2.0 * s.numpy()[0], tol=vgrad_tol)

        tape.zero()
1779
+
1780
+
1781
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Check that independently created but identical vector types are interchangeable.

    The kernel signature uses one set of vector types, while the launch passes
    values built from a second set created with the same length and dtype.
    Inside the kernel, ``wp.expect_eq`` compares the arguments against both sets.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # types used in the kernel signature
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # separately constructed types with identical length and dtype
    vec2_equiv = wp.types.vector(length=2, dtype=wptype)
    vec3_equiv = wp.types.vector(length=3, dtype=wptype)
    vec4_equiv = wp.types.vector(length=4, dtype=wptype)
    vec5_equiv = wp.types.vector(length=5, dtype=wptype)

    def check_equivalence(
        v2: vec2,
        v3: vec3,
        v4: vec4,
        v5: vec5,
    ):
        wp.expect_eq(v2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

        wp.expect_eq(v2, vec2_equiv(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3_equiv(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4_equiv(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5_equiv(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # launch with instances of the equivalent types rather than the declared ones
    wp.launch(
        kernel,
        dim=1,
        inputs=[
            vec2_equiv(1, 2),
            vec3_equiv(1, 2, 3),
            vec4_equiv(1, 2, 3, 4),
            vec5_equiv(1, 2, 3, 4, 5),
        ],
        device=device,
    )
1825
+
1826
+
1827
def test_conversions(test, device, dtype, register_kernels=False):
    """Check that ``wp.vec3`` values can come from tuples, lists and NumPy arrays.

    The same kernel is launched twice. The first launch passes vectors built
    explicitly with ``wp.vec3(...)`` from each container type. The second passes
    the raw containers directly to ``wp.launch``.
    """

    def check_vectors_equal(
        v0: wp.vec3,
        v1: wp.vec3,
        v2: wp.vec3,
        v3: wp.vec3,
    ):
        wp.expect_eq(v1, v0)
        wp.expect_eq(v2, v0)
        wp.expect_eq(v3, v0)

    kernel = getkernel(check_vectors_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.vec3(1, 2, 3)

    # explicit conversion: construct wp.vec3 from each container type
    explicit_inputs = [
        wp.vec3((1, 2, 3)),
        wp.vec3([1, 2, 3]),
        wp.vec3(np.array([1, 2, 3], dtype=dtype)),
    ]
    wp.launch(kernel, dim=1, inputs=[reference, *explicit_inputs], device=device)

    # implicit conversion: hand the containers straight to wp.launch()
    implicit_inputs = [
        (1, 2, 3),
        [1, 2, 3],
        np.array([1, 2, 3], dtype=dtype),
    ]
    wp.launch(kernel, dim=1, inputs=[reference, *implicit_inputs], device=device)
1858
+
1859
+
1860
def test_constants(test, device, dtype, register_kernels=False):
    """Check that vector values wrapped in ``wp.constant`` compare equal inside a kernel."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # the kernel below refers to these constants by name
    cv2 = wp.constant(vec2(1, 2))
    cv3 = wp.constant(vec3(1, 2, 3))
    cv4 = wp.constant(vec4(1, 2, 3, 4))
    cv5 = wp.constant(vec5(1, 2, 3, 4, 5))

    def check_vector_constants():
        wp.expect_eq(cv2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(cv3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(cv4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(cv5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_vector_constants, suffix=dtype.__name__)

    if not register_kernels:
        wp.launch(kernel, dim=1, inputs=[], device=device)
1884
+
1885
+
1886
def test_minmax(test, device, dtype, register_kernels=False):
    """Check component-wise ``wp.min`` / ``wp.max`` on vectors of length 2..5, including gradients.

    Rows of the 10x14 inputs ``a`` and ``b`` are sliced into vec2/vec3/vec4/vec5
    operands. The kernel writes ``2 * min`` and ``2 * max`` back into ``mins`` and
    ``maxs`` at the same (row, column) positions. For every output element the
    gradient check expects 2 on whichever input was selected, and 0 elsewhere.
    """
    rng = np.random.default_rng(123)

    # TODO: float16 results are slightly off on the CPU (but not CUDA); presumably a
    # limitation of the float16 arithmetic implementation, hence the looser tolerance.
    tol = {
        np.float16: 1.0e-2,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # NOTE: this kernel compiles very slowly; the cause is not yet understood.
    def check_vec_min_max(
        a: wp.array(dtype=wptype, ndim=2),
        b: wp.array(dtype=wptype, ndim=2),
        mins: wp.array(dtype=wptype, ndim=2),
        maxs: wp.array(dtype=wptype, ndim=2),
    ):
        for i in range(10):
            # multiplying by 2 so we've got something to backpropagate:
            a2read = vec2(a[i, 0], a[i, 1])
            b2read = vec2(b[i, 0], b[i, 1])
            c2 = wptype(2) * wp.min(a2read, b2read)
            d2 = wptype(2) * wp.max(a2read, b2read)

            a3read = vec3(a[i, 2], a[i, 3], a[i, 4])
            b3read = vec3(b[i, 2], b[i, 3], b[i, 4])
            c3 = wptype(2) * wp.min(a3read, b3read)
            d3 = wptype(2) * wp.max(a3read, b3read)

            a4read = vec4(a[i, 5], a[i, 6], a[i, 7], a[i, 8])
            b4read = vec4(b[i, 5], b[i, 6], b[i, 7], b[i, 8])
            c4 = wptype(2) * wp.min(a4read, b4read)
            d4 = wptype(2) * wp.max(a4read, b4read)

            a5read = vec5(a[i, 9], a[i, 10], a[i, 11], a[i, 12], a[i, 13])
            b5read = vec5(b[i, 9], b[i, 10], b[i, 11], b[i, 12], b[i, 13])
            c5 = wptype(2) * wp.min(a5read, b5read)
            d5 = wptype(2) * wp.max(a5read, b5read)

            mins[i, 0] = c2[0]
            mins[i, 1] = c2[1]

            mins[i, 2] = c3[0]
            mins[i, 3] = c3[1]
            mins[i, 4] = c3[2]

            mins[i, 5] = c4[0]
            mins[i, 6] = c4[1]
            mins[i, 7] = c4[2]
            mins[i, 8] = c4[3]

            mins[i, 9] = c5[0]
            mins[i, 10] = c5[1]
            mins[i, 11] = c5[2]
            mins[i, 12] = c5[3]
            mins[i, 13] = c5[4]

            maxs[i, 0] = d2[0]
            maxs[i, 1] = d2[1]

            maxs[i, 2] = d3[0]
            maxs[i, 3] = d3[1]
            maxs[i, 4] = d3[2]

            maxs[i, 5] = d4[0]
            maxs[i, 6] = d4[1]
            maxs[i, 7] = d4[2]
            maxs[i, 8] = d4[3]

            maxs[i, 9] = d5[0]
            maxs[i, 10] = d5[1]
            maxs[i, 11] = d5[2]
            maxs[i, 12] = d5[3]
            maxs[i, 13] = d5[4]

    kernel = getkernel(check_vec_min_max, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel2(wptype)

    if register_kernels:
        return

    a = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)
    b = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)

    mins = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)
    maxs = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)

    # Forward check only; no tape is needed because nothing is backpropagated here.
    wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)

    # The kernel only reads a and b, so their host values are invariant. Copy them
    # once instead of calling .numpy() (possibly a device-to-host copy) ~1,100
    # times inside the gradient loop below.
    a_np = a.numpy()
    b_np = b.numpy()

    assert_np_equal(mins.numpy(), 2 * np.minimum(a_np, b_np), tol=tol)
    assert_np_equal(maxs.numpy(), 2 * np.maximum(a_np, b_np), tol=tol)

    if dtype not in np_float_types:
        return

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    # (selected output, mask where a is picked, mask where b is picked).
    # Only the picked input should receive the gradient of 2 (from the kernel's
    # factor of 2); this assumes a and b have no exactly-equal elements.
    cases = (
        (mins, a_np < b_np, b_np < a_np),
        (maxs, a_np > b_np, b_np > a_np),
    )

    rows, cols = a_np.shape
    for i in range(rows):
        for j in range(cols):
            for selected, a_picked, b_picked in cases:
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
                    wp.launch(output_select_kernel, dim=1, inputs=[selected, i, j], outputs=[out], device=device)

                tape.backward(loss=out)
                expected = np.zeros_like(a_np)
                expected[i, j] = 2 if a_picked[i, j] else 0
                assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
                expected[i, j] = 2 if b_picked[i, j] else 0
                assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
                tape.zero()
2017
+
2018
+
2019
devices = get_test_devices()


class TestVecScalarOps(unittest.TestCase):
    pass


# (test-name prefix, test function, devices) for tests registered via add_function_test
_PLAIN_TESTS = (
    ("test_arrays", test_arrays, devices),
    ("test_components", test_components, None),
    ("test_py_arithmetic_ops", test_py_arithmetic_ops, None),
)

# (test-name prefix, test function) for tests registered via
# add_function_test_register_kernel on all test devices; order matches registration order
_KERNEL_TESTS = (
    ("test_constructors", test_constructors),
    ("test_anon_type_instance", test_anon_type_instance),
    ("test_indexing", test_indexing),
    ("test_equality", test_equality),
    ("test_scalar_multiplication", test_scalar_multiplication),
    ("test_scalar_multiplication_rightmul", test_scalar_multiplication_rightmul),
    ("test_cw_multiplication", test_cw_multiplication),
    ("test_scalar_division", test_scalar_division),
    ("test_cw_division", test_cw_division),
    ("test_addition", test_addition),
    ("test_dotproduct", test_dotproduct),
    ("test_equivalent_types", test_equivalent_types),
    ("test_conversions", test_conversions),
    ("test_constants", test_constants),
)

# test_minmax is intentionally not registered: its kernel compiles very slowly.
# To enable it, add ("test_minmax", test_minmax) to _KERNEL_TESTS.

for dtype in np_scalar_types:
    for _prefix, _func, _devices in _PLAIN_TESTS:
        add_function_test(TestVecScalarOps, f"{_prefix}_{dtype.__name__}", _func, devices=_devices, dtype=dtype)
    for _prefix, _func in _KERNEL_TESTS:
        add_function_test_register_kernel(
            TestVecScalarOps, f"{_prefix}_{dtype.__name__}", _func, devices=devices, dtype=dtype
        )


if __name__ == "__main__":
    wp.build.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)