warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_array.py CHANGED
@@ -5,14 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
8
+ import unittest
9
+
9
10
  import numpy as np
10
- import math
11
11
 
12
12
  import warp as wp
13
- from warp.tests.test_base import *
14
-
15
- import unittest
13
+ from warp.tests.unittest_utils import *
16
14
 
17
15
  wp.init()
18
16
 
@@ -397,7 +395,7 @@ def test_slicing(test, device):
397
395
  assert_array_equal(wp_arr[:5], wp.array(np_arr[:5], dtype=int, device=device))
398
396
  assert_array_equal(wp_arr[1:5], wp.array(np_arr[1:5], dtype=int, device=device))
399
397
  assert_array_equal(wp_arr[-9:-5:1], wp.array(np_arr[-9:-5:1], dtype=int, device=device))
400
- assert_array_equal(wp_arr[:5,], wp.array(np_arr[:5], dtype=int, device=device))
398
+ assert_array_equal(wp_arr[:5,], wp.array(np_arr[:5], dtype=int, device=device)) # noqa: E231
401
399
 
402
400
 
403
401
  def test_view(test, device):
@@ -738,7 +736,10 @@ def test_fill_matrix(test, device):
738
736
  assert_np_equal(a4.numpy(), np.zeros((*a4.shape, *mat_shape), dtype=nptype))
739
737
 
740
738
  # matrix values can be passed as a 1d numpy array, 2d numpy array, flat list, nested list, or Warp matrix instance
741
- fill_arr1 = np.arange(mat_len, dtype=nptype)
739
+ if wptype != wp.bool:
740
+ fill_arr1 = np.arange(mat_len, dtype=nptype)
741
+ else:
742
+ fill_arr1 = np.ones(mat_len, dtype=nptype)
742
743
  fill_arr2 = fill_arr1.reshape(mat_shape)
743
744
  fill_list1 = list(fill_arr1)
744
745
  fill_list2 = [list(row) for row in fill_arr2]
@@ -1295,7 +1296,10 @@ def test_full_matrix(test, device):
1295
1296
  assert_np_equal(na, np.full(a.size * mattype._length_, 42, dtype=nptype).reshape(npshape))
1296
1297
 
1297
1298
  # fill with 1d numpy array and specific dtype
1298
- fill_arr1d = np.arange(mattype._length_, dtype=nptype)
1299
+ if wptype != wp.bool:
1300
+ fill_arr1d = np.arange(mattype._length_, dtype=nptype)
1301
+ else:
1302
+ fill_arr1d = np.ones(mattype._length_, dtype=nptype)
1299
1303
  a = wp.full(shape, fill_arr1d, dtype=mattype, device=device)
1300
1304
  na = a.numpy()
1301
1305
 
@@ -1448,16 +1452,17 @@ def test_full_struct(test, device):
1448
1452
 
1449
1453
 
1450
1454
  def test_round_trip(test, device):
1455
+ rng = np.random.default_rng(123)
1451
1456
  dim_x = 4
1452
1457
 
1453
1458
  for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1454
- a_np = np.random.randn(dim_x).astype(nptype)
1459
+ a_np = rng.standard_normal(size=dim_x).astype(nptype)
1455
1460
  a = wp.array(a_np, device=device)
1456
1461
  test.assertEqual(a.dtype, wptype)
1457
1462
 
1458
1463
  assert_np_equal(a.numpy(), a_np)
1459
1464
 
1460
- v_np = np.random.randn(dim_x, 3).astype(nptype)
1465
+ v_np = rng.standard_normal(size=(dim_x, 3)).astype(nptype)
1461
1466
  v = wp.array(v_np, dtype=wp.types.vector(3, wptype), device=device)
1462
1467
 
1463
1468
  assert_np_equal(v.numpy(), v_np)
@@ -1695,6 +1700,7 @@ def test_to_list_struct(test, device):
1695
1700
  a1: wp.array(dtype=int)
1696
1701
  a2: wp.array2d(dtype=float)
1697
1702
  a3: wp.array3d(dtype=wp.float16)
1703
+ bool: wp.bool
1698
1704
 
1699
1705
  dim = 3
1700
1706
 
@@ -1714,6 +1720,7 @@ def test_to_list_struct(test, device):
1714
1720
  s.a1 = wp.empty(1, dtype=int, device=device)
1715
1721
  s.a2 = wp.empty((1, 1), dtype=float, device=device)
1716
1722
  s.a3 = wp.empty((1, 1, 1), dtype=wp.float16, device=device)
1723
+ s.bool = True
1717
1724
 
1718
1725
  for ndim in range(1, 5):
1719
1726
  shape = (dim,) * ndim
@@ -1731,6 +1738,7 @@ def test_to_list_struct(test, device):
1731
1738
  test.assertEqual(l[i].mi, s.mi)
1732
1739
  test.assertEqual(l[i].mf, s.mf)
1733
1740
  test.assertEqual(l[i].mh, s.mh)
1741
+ test.assertEqual(l[i].bool, s.bool)
1734
1742
  test.assertEqual(l[i].inner.h, s.inner.h)
1735
1743
  test.assertEqual(l[i].inner.v, s.inner.v)
1736
1744
  test.assertEqual(l[i].a1.dtype, s.a1.dtype)
@@ -1741,46 +1749,6 @@ def test_to_list_struct(test, device):
1741
1749
  test.assertEqual(l[i].a3.ndim, s.a3.ndim)
1742
1750
 
1743
1751
 
1744
- def test_large_arrays_slow(test, device):
1745
- # The goal of this test is to use arrays just large enough to know
1746
- # if there's a flaw in handling arrays with more than 2**31-1 elements
1747
- # Unfortunately, it takes a long time to run so it won't be run automatically
1748
- # without changes to support how frequently a test may be run
1749
- total_elements = 2**31 + 8
1750
-
1751
- # 1-D to 4-D arrays: test zero_, fill_, then zero_ for scalar data types:
1752
- for total_dims in range(1, 5):
1753
- dim_x = math.ceil(total_elements ** (1 / total_dims))
1754
- shape_tuple = tuple([dim_x] * total_dims)
1755
-
1756
- for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1757
- a1 = wp.zeros(shape_tuple, dtype=wptype, device=device)
1758
- assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
1759
-
1760
- a1.fill_(127)
1761
- assert_np_equal(a1.numpy(), 127 * np.ones_like(a1.numpy()))
1762
-
1763
- a1.zero_()
1764
- assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
1765
-
1766
-
1767
- def test_large_arrays_fast(test, device):
1768
- # A truncated version of test_large_arrays_slow meant to catch basic errors
1769
- total_elements = 2**31 + 8
1770
-
1771
- nptype = np.dtype(np.int8)
1772
- wptype = wp.types.np_dtype_to_warp_type[nptype]
1773
-
1774
- a1 = wp.zeros((total_elements,), dtype=wptype, device=device)
1775
- assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
1776
-
1777
- a1.fill_(127)
1778
- assert_np_equal(a1.numpy(), 127 * np.ones_like(a1.numpy()))
1779
-
1780
- a1.zero_()
1781
- assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
1782
-
1783
-
1784
1752
  @wp.kernel
1785
1753
  def kernel_array_to_bool(array_null: wp.array(dtype=float), array_valid: wp.array(dtype=float)):
1786
1754
  if not array_null:
@@ -1969,54 +1937,201 @@ def test_array_of_structs_roundtrip(test, device):
1969
1937
  assert_np_equal(a.numpy(), expected)
1970
1938
 
1971
1939
 
1972
- def register(parent):
1973
- devices = get_test_devices()
1974
-
1975
- class TestArray(parent):
1976
- pass
1977
-
1978
- add_function_test(TestArray, "test_shape", test_shape, devices=devices)
1979
- add_function_test(TestArray, "test_flatten", test_flatten, devices=devices)
1980
- add_function_test(TestArray, "test_reshape", test_reshape, devices=devices)
1981
- add_function_test(TestArray, "test_slicing", test_slicing, devices=devices)
1982
- add_function_test(TestArray, "test_transpose", test_transpose, devices=devices)
1983
- add_function_test(TestArray, "test_view", test_view, devices=devices)
1984
-
1985
- add_function_test(TestArray, "test_1d_array", test_1d, devices=devices)
1986
- add_function_test(TestArray, "test_2d_array", test_2d, devices=devices)
1987
- add_function_test(TestArray, "test_3d_array", test_3d, devices=devices)
1988
- add_function_test(TestArray, "test_4d_array", test_4d, devices=devices)
1989
- add_function_test(TestArray, "test_4d_array_transposed", test_4d_transposed, devices=devices)
1990
-
1991
- add_function_test(TestArray, "test_fill_scalar", test_fill_scalar, devices=devices)
1992
- add_function_test(TestArray, "test_fill_vector", test_fill_vector, devices=devices)
1993
- add_function_test(TestArray, "test_fill_matrix", test_fill_matrix, devices=devices)
1994
- add_function_test(TestArray, "test_fill_struct", test_fill_struct, devices=devices)
1995
- add_function_test(TestArray, "test_fill_slices", test_fill_slices, devices=devices)
1996
- add_function_test(TestArray, "test_full_scalar", test_full_scalar, devices=devices)
1997
- add_function_test(TestArray, "test_full_vector", test_full_vector, devices=devices)
1998
- add_function_test(TestArray, "test_full_matrix", test_full_matrix, devices=devices)
1999
- add_function_test(TestArray, "test_full_struct", test_full_struct, devices=devices)
2000
- add_function_test(TestArray, "test_empty_array", test_empty_array, devices=devices)
2001
- add_function_test(TestArray, "test_empty_from_numpy", test_empty_from_numpy, devices=devices)
2002
- add_function_test(TestArray, "test_empty_from_list", test_empty_from_list, devices=devices)
2003
- add_function_test(TestArray, "test_to_list_scalar", test_to_list_scalar, devices=devices)
2004
- add_function_test(TestArray, "test_to_list_vector", test_to_list_vector, devices=devices)
2005
- add_function_test(TestArray, "test_to_list_matrix", test_to_list_matrix, devices=devices)
2006
- add_function_test(TestArray, "test_to_list_struct", test_to_list_struct, devices=devices)
2007
-
2008
- add_function_test(TestArray, "test_lower_bound", test_lower_bound, devices=devices)
2009
- add_function_test(TestArray, "test_round_trip", test_round_trip, devices=devices)
2010
- add_function_test(TestArray, "test_large_arrays_fast", test_large_arrays_fast, devices=devices)
2011
- add_function_test(TestArray, "test_array_to_bool", test_array_to_bool, devices=devices)
2012
- add_function_test(TestArray, "test_array_of_structs", test_array_of_structs, devices=devices)
2013
- add_function_test(TestArray, "test_array_of_structs_grad", test_array_of_structs_grad, devices=devices)
2014
- add_function_test(TestArray, "test_array_of_structs_from_numpy", test_array_of_structs_from_numpy, devices=devices)
2015
- add_function_test(TestArray, "test_array_of_structs_roundtrip", test_array_of_structs_roundtrip, devices=devices)
2016
-
2017
- return TestArray
1940
+ def test_array_from_numpy(test, device):
1941
+ arr = np.array((1.0, 2.0, 3.0), dtype=float)
1942
+
1943
+ result = wp.from_numpy(arr)
1944
+ expected = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, shape=(3,))
1945
+ assert_np_equal(result.numpy(), expected.numpy())
1946
+
1947
+ result = wp.from_numpy(arr, dtype=wp.vec3)
1948
+ expected = wp.array(((1.0, 2.0, 3.0),), dtype=wp.vec3, shape=(1,))
1949
+ assert_np_equal(result.numpy(), expected.numpy())
1950
+
1951
+ # --------------------------------------------------------------------------
1952
+
1953
+ arr = np.array(((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)), dtype=float)
1954
+
1955
+ result = wp.from_numpy(arr)
1956
+ expected = wp.array(((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)), dtype=wp.vec3, shape=(2,))
1957
+ assert_np_equal(result.numpy(), expected.numpy())
1958
+
1959
+ result = wp.from_numpy(arr, dtype=wp.float32)
1960
+ expected = wp.array(((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)), dtype=wp.float32, shape=(2, 3))
1961
+ assert_np_equal(result.numpy(), expected.numpy())
1962
+
1963
+ result = wp.from_numpy(arr, dtype=wp.float32, shape=(6,))
1964
+ expected = wp.array((1.0, 2.0, 3.0, 4.0, 5.0, 6.0), dtype=wp.float32, shape=(6,))
1965
+ assert_np_equal(result.numpy(), expected.numpy())
1966
+
1967
+ # --------------------------------------------------------------------------
1968
+
1969
+ arr = np.array(
1970
+ (
1971
+ (
1972
+ (1.0, 2.0, 3.0, 4.0),
1973
+ (2.0, 3.0, 4.0, 5.0),
1974
+ (3.0, 4.0, 5.0, 6.0),
1975
+ (4.0, 5.0, 6.0, 7.0),
1976
+ ),
1977
+ (
1978
+ (2.0, 3.0, 4.0, 5.0),
1979
+ (3.0, 4.0, 5.0, 6.0),
1980
+ (4.0, 5.0, 6.0, 7.0),
1981
+ (5.0, 6.0, 7.0, 8.0),
1982
+ ),
1983
+ ),
1984
+ dtype=float,
1985
+ )
1986
+
1987
+ result = wp.from_numpy(arr)
1988
+ expected = wp.array(
1989
+ (
1990
+ (
1991
+ (1.0, 2.0, 3.0, 4.0),
1992
+ (2.0, 3.0, 4.0, 5.0),
1993
+ (3.0, 4.0, 5.0, 6.0),
1994
+ (4.0, 5.0, 6.0, 7.0),
1995
+ ),
1996
+ (
1997
+ (2.0, 3.0, 4.0, 5.0),
1998
+ (3.0, 4.0, 5.0, 6.0),
1999
+ (4.0, 5.0, 6.0, 7.0),
2000
+ (5.0, 6.0, 7.0, 8.0),
2001
+ ),
2002
+ ),
2003
+ dtype=wp.mat44,
2004
+ shape=(2,),
2005
+ )
2006
+ assert_np_equal(result.numpy(), expected.numpy())
2007
+
2008
+ result = wp.from_numpy(arr, dtype=wp.float32)
2009
+ expected = wp.array(
2010
+ (
2011
+ (
2012
+ (1.0, 2.0, 3.0, 4.0),
2013
+ (2.0, 3.0, 4.0, 5.0),
2014
+ (3.0, 4.0, 5.0, 6.0),
2015
+ (4.0, 5.0, 6.0, 7.0),
2016
+ ),
2017
+ (
2018
+ (2.0, 3.0, 4.0, 5.0),
2019
+ (3.0, 4.0, 5.0, 6.0),
2020
+ (4.0, 5.0, 6.0, 7.0),
2021
+ (5.0, 6.0, 7.0, 8.0),
2022
+ ),
2023
+ ),
2024
+ dtype=wp.float32,
2025
+ shape=(2, 4, 4),
2026
+ )
2027
+ assert_np_equal(result.numpy(), expected.numpy())
2028
+
2029
+ result = wp.from_numpy(arr, dtype=wp.vec4)
2030
+ expected = wp.array(
2031
+ (
2032
+ (1.0, 2.0, 3.0, 4.0),
2033
+ (2.0, 3.0, 4.0, 5.0),
2034
+ (3.0, 4.0, 5.0, 6.0),
2035
+ (4.0, 5.0, 6.0, 7.0),
2036
+ (2.0, 3.0, 4.0, 5.0),
2037
+ (3.0, 4.0, 5.0, 6.0),
2038
+ (4.0, 5.0, 6.0, 7.0),
2039
+ (5.0, 6.0, 7.0, 8.0),
2040
+ ),
2041
+ dtype=wp.vec4,
2042
+ shape=(8,),
2043
+ )
2044
+ assert_np_equal(result.numpy(), expected.numpy())
2045
+
2046
+ result = wp.from_numpy(arr, dtype=wp.float32, shape=(32,))
2047
+ expected = wp.array(
2048
+ (
2049
+ 1.0,
2050
+ 2.0,
2051
+ 3.0,
2052
+ 4.0,
2053
+ 2.0,
2054
+ 3.0,
2055
+ 4.0,
2056
+ 5.0,
2057
+ 3.0,
2058
+ 4.0,
2059
+ 5.0,
2060
+ 6.0,
2061
+ 4.0,
2062
+ 5.0,
2063
+ 6.0,
2064
+ 7.0,
2065
+ 2.0,
2066
+ 3.0,
2067
+ 4.0,
2068
+ 5.0,
2069
+ 3.0,
2070
+ 4.0,
2071
+ 5.0,
2072
+ 6.0,
2073
+ 4.0,
2074
+ 5.0,
2075
+ 6.0,
2076
+ 7.0,
2077
+ 5.0,
2078
+ 6.0,
2079
+ 7.0,
2080
+ 8.0,
2081
+ ),
2082
+ dtype=wp.float32,
2083
+ shape=(32,),
2084
+ )
2085
+ assert_np_equal(result.numpy(), expected.numpy())
2086
+
2087
+
2088
+ devices = get_test_devices()
2089
+
2090
+
2091
+ class TestArray(unittest.TestCase):
2092
+ pass
2093
+
2094
+
2095
+ add_function_test(TestArray, "test_shape", test_shape, devices=devices)
2096
+ add_function_test(TestArray, "test_flatten", test_flatten, devices=devices)
2097
+ add_function_test(TestArray, "test_reshape", test_reshape, devices=devices)
2098
+ add_function_test(TestArray, "test_slicing", test_slicing, devices=devices)
2099
+ add_function_test(TestArray, "test_transpose", test_transpose, devices=devices)
2100
+ add_function_test(TestArray, "test_view", test_view, devices=devices)
2101
+
2102
+ add_function_test(TestArray, "test_1d_array", test_1d, devices=devices)
2103
+ add_function_test(TestArray, "test_2d_array", test_2d, devices=devices)
2104
+ add_function_test(TestArray, "test_3d_array", test_3d, devices=devices)
2105
+ add_function_test(TestArray, "test_4d_array", test_4d, devices=devices)
2106
+ add_function_test(TestArray, "test_4d_array_transposed", test_4d_transposed, devices=devices)
2107
+
2108
+ add_function_test(TestArray, "test_fill_scalar", test_fill_scalar, devices=devices)
2109
+ add_function_test(TestArray, "test_fill_vector", test_fill_vector, devices=devices)
2110
+ add_function_test(TestArray, "test_fill_matrix", test_fill_matrix, devices=devices)
2111
+ add_function_test(TestArray, "test_fill_struct", test_fill_struct, devices=devices)
2112
+ add_function_test(TestArray, "test_fill_slices", test_fill_slices, devices=devices)
2113
+ add_function_test(TestArray, "test_full_scalar", test_full_scalar, devices=devices)
2114
+ add_function_test(TestArray, "test_full_vector", test_full_vector, devices=devices)
2115
+ add_function_test(TestArray, "test_full_matrix", test_full_matrix, devices=devices)
2116
+ add_function_test(TestArray, "test_full_struct", test_full_struct, devices=devices)
2117
+ add_function_test(TestArray, "test_empty_array", test_empty_array, devices=devices)
2118
+ add_function_test(TestArray, "test_empty_from_numpy", test_empty_from_numpy, devices=devices)
2119
+ add_function_test(TestArray, "test_empty_from_list", test_empty_from_list, devices=devices)
2120
+ add_function_test(TestArray, "test_to_list_scalar", test_to_list_scalar, devices=devices)
2121
+ add_function_test(TestArray, "test_to_list_vector", test_to_list_vector, devices=devices)
2122
+ add_function_test(TestArray, "test_to_list_matrix", test_to_list_matrix, devices=devices)
2123
+ add_function_test(TestArray, "test_to_list_struct", test_to_list_struct, devices=devices)
2124
+
2125
+ add_function_test(TestArray, "test_lower_bound", test_lower_bound, devices=devices)
2126
+ add_function_test(TestArray, "test_round_trip", test_round_trip, devices=devices)
2127
+ add_function_test(TestArray, "test_array_to_bool", test_array_to_bool, devices=devices)
2128
+ add_function_test(TestArray, "test_array_of_structs", test_array_of_structs, devices=devices)
2129
+ add_function_test(TestArray, "test_array_of_structs_grad", test_array_of_structs_grad, devices=devices)
2130
+ add_function_test(TestArray, "test_array_of_structs_from_numpy", test_array_of_structs_from_numpy, devices=devices)
2131
+ add_function_test(TestArray, "test_array_of_structs_roundtrip", test_array_of_structs_roundtrip, devices=devices)
2132
+ add_function_test(TestArray, "test_array_from_numpy", test_array_from_numpy, devices=devices)
2018
2133
 
2019
2134
 
2020
2135
  if __name__ == "__main__":
2021
- c = register(unittest.TestCase)
2136
+ wp.build.clear_kernel_cache()
2022
2137
  unittest.main(verbosity=2)
@@ -1,8 +1,17 @@
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
1
10
  import numpy as np
2
- import warp as wp
3
11
 
4
- from warp.utils import array_sum, array_inner
5
- from warp.tests.test_base import *
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+ from warp.utils import array_inner, array_sum
6
15
 
7
16
  wp.init()
8
17
 
@@ -11,9 +20,11 @@ def make_test_array_sum(dtype):
11
20
  N = 1000
12
21
 
13
22
  def test_array_sum(test, device):
23
+ rng = np.random.default_rng(123)
24
+
14
25
  cols = wp.types.type_length(dtype)
15
26
 
16
- values_np = np.random.rand(N, cols)
27
+ values_np = rng.random(size=(N, cols))
17
28
  values = wp.array(values_np, device=device, dtype=dtype)
18
29
 
19
30
  vsum = array_sum(values)
@@ -32,7 +43,9 @@ def make_test_array_sum_axis(dtype):
32
43
  N = I * J * K
33
44
 
34
45
  def test_array_sum(test, device):
35
- values_np = np.random.rand(I, J, K)
46
+ rng = np.random.default_rng(123)
47
+
48
+ values_np = rng.random(size=(I, J, K))
36
49
  values = wp.array(values_np, shape=(I, J, K), device=device, dtype=dtype)
37
50
 
38
51
  for axis in range(3):
@@ -44,14 +57,24 @@ def make_test_array_sum_axis(dtype):
44
57
  return test_array_sum
45
58
 
46
59
 
60
+ def test_array_sum_empty(test, device):
61
+ values = wp.array([], device=device, dtype=wp.vec2)
62
+ assert_np_equal(array_sum(values), np.zeros(2))
63
+
64
+ values = wp.array([], shape=(0, 3), device=device, dtype=float)
65
+ assert_np_equal(array_sum(values, axis=0).numpy(), np.zeros(3))
66
+
67
+
47
68
  def make_test_array_inner(dtype):
48
69
  N = 1000
49
70
 
50
71
  def test_array_inner(test, device):
72
+ rng = np.random.default_rng(123)
73
+
51
74
  cols = wp.types.type_length(dtype)
52
75
 
53
- a_np = np.random.rand(N, cols)
54
- b_np = np.random.rand(N, cols)
76
+ a_np = rng.random(size=(N, cols))
77
+ b_np = rng.random(size=(N, cols))
55
78
 
56
79
  a = wp.array(a_np, device=device, dtype=dtype)
57
80
  b = wp.array(b_np, device=device, dtype=dtype)
@@ -72,8 +95,10 @@ def make_test_array_inner_axis(dtype):
72
95
  N = I * J * K
73
96
 
74
97
  def test_array_inner(test, device):
75
- a_np = np.random.rand(I, J, K)
76
- b_np = np.random.rand(I, J, K)
98
+ rng = np.random.default_rng(123)
99
+
100
+ a_np = rng.random(size=(I, J, K))
101
+ b_np = rng.random(size=(I, J, K))
77
102
 
78
103
  a = wp.array(a_np, shape=(I, J, K), device=device, dtype=dtype)
79
104
  b = wp.array(b_np, shape=(I, J, K), device=device, dtype=dtype)
@@ -93,24 +118,33 @@ def make_test_array_inner_axis(dtype):
93
118
  return test_array_inner
94
119
 
95
120
 
96
- def register(parent):
97
- devices = get_test_devices()
121
+ def test_array_inner_empty(test, device):
122
+ values = wp.array([], device=device, dtype=wp.vec2)
123
+ test.assertEqual(array_inner(values, values), 0.0)
124
+
125
+ values = wp.array([], shape=(0, 3), device=device, dtype=float)
126
+ assert_np_equal(array_inner(values, values, axis=0).numpy(), np.zeros(3))
127
+
128
+
129
+ devices = get_test_devices()
130
+
98
131
 
99
- class TestArraySym(parent):
100
- pass
132
+ class TestArrayReduce(unittest.TestCase):
133
+ pass
101
134
 
102
- add_function_test(TestArraySym, "test_array_sum_double", make_test_array_sum(wp.float64), devices=devices)
103
- add_function_test(TestArraySym, "test_array_sum_vec3", make_test_array_sum(wp.vec3), devices=devices)
104
- add_function_test(TestArraySym, "test_array_sum_axis_float", make_test_array_sum_axis(wp.float32), devices=devices)
105
- add_function_test(TestArraySym, "test_array_inner_double", make_test_array_inner(wp.float64), devices=devices)
106
- add_function_test(TestArraySym, "test_array_inner_vec3", make_test_array_inner(wp.vec3), devices=devices)
107
- add_function_test(
108
- TestArraySym, "test_array_inner_axis_float", make_test_array_inner_axis(wp.float32), devices=devices
109
- )
110
135
 
111
- return TestArraySym
136
+ add_function_test(TestArrayReduce, "test_array_sum_double", make_test_array_sum(wp.float64), devices=devices)
137
+ add_function_test(TestArrayReduce, "test_array_sum_vec3", make_test_array_sum(wp.vec3), devices=devices)
138
+ add_function_test(TestArrayReduce, "test_array_sum_axis_float", make_test_array_sum_axis(wp.float32), devices=devices)
139
+ add_function_test(TestArrayReduce, "test_array_sum_empty", test_array_sum_empty, devices=devices)
140
+ add_function_test(TestArrayReduce, "test_array_inner_double", make_test_array_inner(wp.float64), devices=devices)
141
+ add_function_test(TestArrayReduce, "test_array_inner_vec3", make_test_array_inner(wp.vec3), devices=devices)
142
+ add_function_test(
143
+ TestArrayReduce, "test_array_inner_axis_float", make_test_array_inner_axis(wp.float32), devices=devices
144
+ )
145
+ add_function_test(TestArrayReduce, "test_array_inner_empty", test_array_inner_empty, devices=devices)
112
146
 
113
147
 
114
148
  if __name__ == "__main__":
115
- c = register(unittest.TestCase)
149
+ wp.build.clear_kernel_cache()
116
150
  unittest.main(verbosity=2)
warp/tests/test_atomic.py CHANGED
@@ -5,14 +5,12 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
8
+ import unittest
9
+
9
10
  import numpy as np
10
- import math
11
11
 
12
12
  import warp as wp
13
- from warp.tests.test_base import *
14
-
15
- import unittest
13
+ from warp.tests.unittest_utils import *
16
14
 
17
15
  wp.init()
18
16
 
@@ -34,8 +32,7 @@ def make_atomic_test(type):
34
32
  # register a custom kernel (no decorator) function
35
33
  # this lets us register the same function definition
36
34
  # against multiple symbols, with different arg types
37
- module = wp.get_module(test_atomic_kernel.__module__)
38
- kernel = wp.Kernel(func=test_atomic_kernel, key=f"test_atomic_{type.__name__}_kernel", module=module)
35
+ kernel = wp.Kernel(func=test_atomic_kernel, key=f"test_atomic_{type.__name__}_kernel")
39
36
 
40
37
  def test_atomic(test, device):
41
38
  n = 1024
@@ -54,20 +51,60 @@ def make_atomic_test(type):
54
51
  base = rng.random(size=(1, *type._shape_), dtype=float)
55
52
  val = rng.random(size=(n, *type._shape_), dtype=float)
56
53
 
57
- add_array = wp.array(base, dtype=type, device=device)
58
- min_array = wp.array(base, dtype=type, device=device)
59
- max_array = wp.array(base, dtype=type, device=device)
54
+ add_array = wp.array(base, dtype=type, device=device, requires_grad=True)
55
+ min_array = wp.array(base, dtype=type, device=device, requires_grad=True)
56
+ max_array = wp.array(base, dtype=type, device=device, requires_grad=True)
57
+ add_array.zero_()
58
+ min_array.fill_(10000)
59
+ max_array.fill_(-10000)
60
60
 
61
- val_array = wp.array(val, dtype=type, device=device)
61
+ val_array = wp.array(val, dtype=type, device=device, requires_grad=True)
62
62
 
63
- wp.launch(kernel, n, inputs=[add_array, min_array, max_array, val_array], device=device)
64
-
65
- val = np.append(val, [base[0]], axis=0)
63
+ tape = wp.Tape()
64
+ with tape:
65
+ wp.launch(kernel, n, inputs=[add_array, min_array, max_array, val_array], device=device)
66
66
 
67
67
  assert_np_equal(add_array.numpy(), np.sum(val, axis=0), tol=1.0e-2)
68
68
  assert_np_equal(min_array.numpy(), np.min(val, axis=0), tol=1.0e-2)
69
69
  assert_np_equal(max_array.numpy(), np.max(val, axis=0), tol=1.0e-2)
70
70
 
71
+ if type != wp.int32:
72
+ add_array.grad.fill_(1)
73
+ tape.backward()
74
+ assert_np_equal(val_array.grad.numpy(), np.ones_like(val))
75
+ tape.zero()
76
+
77
+ min_array.grad.fill_(1)
78
+ tape.backward()
79
+ min_grad_array = np.zeros_like(val)
80
+ argmin = val.argmin(axis=0)
81
+ if val.ndim == 1:
82
+ min_grad_array[argmin] = 1
83
+ elif val.ndim == 2:
84
+ for i in range(val.shape[1]):
85
+ min_grad_array[argmin[i], i] = 1
86
+ elif val.ndim == 3:
87
+ for i in range(val.shape[1]):
88
+ for j in range(val.shape[2]):
89
+ min_grad_array[argmin[i, j], i, j] = 1
90
+ assert_np_equal(val_array.grad.numpy(), min_grad_array)
91
+ tape.zero()
92
+
93
+ max_array.grad.fill_(1)
94
+ tape.backward()
95
+ max_grad_array = np.zeros_like(val)
96
+ argmax = val.argmax(axis=0)
97
+ if val.ndim == 1:
98
+ max_grad_array[argmax] = 1
99
+ elif val.ndim == 2:
100
+ for i in range(val.shape[1]):
101
+ max_grad_array[argmax[i], i] = 1
102
+ elif val.ndim == 3:
103
+ for i in range(val.shape[1]):
104
+ for j in range(val.shape[2]):
105
+ max_grad_array[argmax[i, j], i, j] = 1
106
+ assert_np_equal(val_array.grad.numpy(), max_grad_array)
107
+
71
108
  return test_atomic
72
109
 
73
110
 
@@ -82,24 +119,23 @@ test_atomic_mat33 = make_atomic_test(wp.mat33)
82
119
  test_atomic_mat44 = make_atomic_test(wp.mat44)
83
120
 
84
121
 
85
- def register(parent):
86
- devices = get_test_devices()
122
+ devices = get_test_devices()
123
+
87
124
 
88
- class TestAtomic(parent):
89
- pass
125
+ class TestAtomic(unittest.TestCase):
126
+ pass
90
127
 
91
- add_function_test(TestAtomic, "test_atomic_int", test_atomic_int, devices=devices)
92
- add_function_test(TestAtomic, "test_atomic_float", test_atomic_float, devices=devices)
93
- add_function_test(TestAtomic, "test_atomic_vec2", test_atomic_vec2, devices=devices)
94
- add_function_test(TestAtomic, "test_atomic_vec3", test_atomic_vec3, devices=devices)
95
- add_function_test(TestAtomic, "test_atomic_vec4", test_atomic_vec4, devices=devices)
96
- add_function_test(TestAtomic, "test_atomic_mat22", test_atomic_mat22, devices=devices)
97
- add_function_test(TestAtomic, "test_atomic_mat33", test_atomic_mat33, devices=devices)
98
- add_function_test(TestAtomic, "test_atomic_mat44", test_atomic_mat44, devices=devices)
99
128
 
100
- return TestAtomic
129
+ add_function_test(TestAtomic, "test_atomic_int", test_atomic_int, devices=devices)
130
+ add_function_test(TestAtomic, "test_atomic_float", test_atomic_float, devices=devices)
131
+ add_function_test(TestAtomic, "test_atomic_vec2", test_atomic_vec2, devices=devices)
132
+ add_function_test(TestAtomic, "test_atomic_vec3", test_atomic_vec3, devices=devices)
133
+ add_function_test(TestAtomic, "test_atomic_vec4", test_atomic_vec4, devices=devices)
134
+ add_function_test(TestAtomic, "test_atomic_mat22", test_atomic_mat22, devices=devices)
135
+ add_function_test(TestAtomic, "test_atomic_mat33", test_atomic_mat33, devices=devices)
136
+ add_function_test(TestAtomic, "test_atomic_mat44", test_atomic_mat44, devices=devices)
101
137
 
102
138
 
103
139
  if __name__ == "__main__":
104
- c = register(unittest.TestCase)
140
+ wp.build.clear_kernel_cache()
105
141
  unittest.main(verbosity=2)