warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_adam.py CHANGED
@@ -5,16 +5,14 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- # include parent path
8
+ import unittest
9
+
9
10
  import numpy as np
10
- import math
11
11
 
12
12
  import warp as wp
13
- from warp.tests.test_base import *
14
- import unittest
15
-
16
13
  import warp.optim
17
14
  import warp.sim
15
+ from warp.tests.unittest_utils import *
18
16
 
19
17
  wp.init()
20
18
 
@@ -28,32 +26,32 @@ def objective(params: wp.array(dtype=float), score: wp.array(dtype=float)):
28
26
 
29
27
  # This test inspired by https://machinelearningmastery.com/adam-optimization-from-scratch/
30
28
  def test_adam_solve_float(test, device):
31
- wp.set_device(device)
32
- params_start = np.array([0.1, 0.2], dtype=float)
33
- score = wp.zeros(1, dtype=float, requires_grad=True)
34
- params = wp.array(params_start, dtype=float, requires_grad=True)
35
- tape = wp.Tape()
36
- opt = warp.optim.Adam([params], lr=0.02, betas=(0.8, 0.999))
37
-
38
- def gradient_func():
39
- tape.reset()
40
- score.zero_()
41
- with tape:
42
- wp.launch(kernel=objective, dim=len(params), inputs=[params, score])
43
- tape.backward(score)
44
- return [tape.gradients[params]]
45
-
46
- niters = 100
47
-
48
- opt.reset_internal_state()
49
- for _ in range(niters):
50
- opt.step(gradient_func())
51
-
52
- result = params.numpy()
53
- # optimum is at the origin, so the result should be close to it in all N dimensions.
54
- tol = 1e-5
55
- for r in result:
56
- test.assertLessEqual(r, tol)
29
+ with wp.ScopedDevice(device):
30
+ params_start = np.array([0.1, 0.2], dtype=float)
31
+ score = wp.zeros(1, dtype=float, requires_grad=True)
32
+ params = wp.array(params_start, dtype=float, requires_grad=True)
33
+ tape = wp.Tape()
34
+ opt = warp.optim.Adam([params], lr=0.02, betas=(0.8, 0.999))
35
+
36
+ def gradient_func():
37
+ tape.reset()
38
+ score.zero_()
39
+ with tape:
40
+ wp.launch(kernel=objective, dim=len(params), inputs=[params, score])
41
+ tape.backward(score)
42
+ return [tape.gradients[params]]
43
+
44
+ niters = 100
45
+
46
+ opt.reset_internal_state()
47
+ for _ in range(niters):
48
+ opt.step(gradient_func())
49
+
50
+ result = params.numpy()
51
+ # optimum is at the origin, so the result should be close to it in all N dimensions.
52
+ tol = 1e-5
53
+ for r in result:
54
+ test.assertLessEqual(r, tol)
57
55
 
58
56
 
59
57
  @wp.kernel
@@ -65,32 +63,32 @@ def objective_vec3(params: wp.array(dtype=wp.vec3), score: wp.array(dtype=float)
65
63
 
66
64
  # This test inspired by https://machinelearningmastery.com/adam-optimization-from-scratch/
67
65
  def test_adam_solve_vec3(test, device):
68
- wp.set_device(device)
69
- params_start = np.array([[0.1, 0.2, -0.1]], dtype=float)
70
- score = wp.zeros(1, dtype=float, requires_grad=True)
71
- params = wp.array(params_start, dtype=wp.vec3, requires_grad=True)
72
- tape = wp.Tape()
73
- opt = warp.optim.Adam([params], lr=0.02, betas=(0.8, 0.999))
74
-
75
- def gradient_func():
76
- tape.reset()
77
- score.zero_()
78
- with tape:
79
- wp.launch(kernel=objective_vec3, dim=len(params), inputs=[params, score])
80
- tape.backward(score)
81
- return [tape.gradients[params]]
82
-
83
- niters = 100
84
- opt.reset_internal_state()
85
- for _ in range(niters):
86
- opt.step(gradient_func())
87
-
88
- result = params.numpy()
89
- tol = 1e-5
90
- # optimum is at the origin, so the result should be close to it in all N dimensions.
91
- for r in result:
92
- for v in r:
93
- test.assertLessEqual(v, tol)
66
+ with wp.ScopedDevice(device):
67
+ params_start = np.array([[0.1, 0.2, -0.1]], dtype=float)
68
+ score = wp.zeros(1, dtype=float, requires_grad=True)
69
+ params = wp.array(params_start, dtype=wp.vec3, requires_grad=True)
70
+ tape = wp.Tape()
71
+ opt = warp.optim.Adam([params], lr=0.02, betas=(0.8, 0.999))
72
+
73
+ def gradient_func():
74
+ tape.reset()
75
+ score.zero_()
76
+ with tape:
77
+ wp.launch(kernel=objective_vec3, dim=len(params), inputs=[params, score])
78
+ tape.backward(score)
79
+ return [tape.gradients[params]]
80
+
81
+ niters = 100
82
+ opt.reset_internal_state()
83
+ for _ in range(niters):
84
+ opt.step(gradient_func())
85
+
86
+ result = params.numpy()
87
+ tol = 1e-5
88
+ # optimum is at the origin, so the result should be close to it in all N dimensions.
89
+ for r in result:
90
+ for v in r:
91
+ test.assertLessEqual(v, tol)
94
92
 
95
93
 
96
94
  @wp.kernel
@@ -105,56 +103,55 @@ def objective_two_inputs_vec3(
105
103
 
106
104
  # This test inspired by https://machinelearningmastery.com/adam-optimization-from-scratch/
107
105
  def test_adam_solve_two_inputs(test, device):
108
- wp.set_device(device)
109
- params_start1 = np.array([[0.1, 0.2, -0.1]], dtype=float)
110
- params_start2 = np.array([[0.2, 0.1, 0.1]], dtype=float)
111
- score = wp.zeros(1, dtype=float, requires_grad=True)
112
- params1 = wp.array(params_start1, dtype=wp.vec3, requires_grad=True)
113
- params2 = wp.array(params_start2, dtype=wp.vec3, requires_grad=True)
114
- tape = wp.Tape()
115
- opt = warp.optim.Adam([params1, params2], lr=0.02, betas=(0.8, 0.999))
116
-
117
- def gradient_func():
118
- tape.reset()
119
- score.zero_()
120
- with tape:
121
- wp.launch(kernel=objective_two_inputs_vec3, dim=len(params1), inputs=[params1, params2, score])
122
- tape.backward(score)
123
- return [tape.gradients[params1], tape.gradients[params2]]
124
-
125
- niters = 100
126
- opt.reset_internal_state()
127
- for _ in range(niters):
128
- opt.step(gradient_func())
129
-
130
- result = params1.numpy()
131
- tol = 1e-5
132
- # optimum is at the origin, so the result should be close to it in all N dimensions.
133
- for r in result:
134
- for v in r:
135
- test.assertLessEqual(v, tol)
136
-
137
- result = params2.numpy()
138
- tol = 1e-5
139
- # optimum is at the origin, so the result should be close to it in all N dimensions.
140
- for r in result:
141
- for v in r:
142
- test.assertLessEqual(v, tol)
143
-
144
-
145
- def register(parent):
146
- devices = get_test_devices()
147
-
148
- class TestArray(parent):
149
- pass
150
-
151
- add_function_test(TestArray, "test_adam_solve_float", test_adam_solve_float, devices=devices)
152
- add_function_test(TestArray, "test_adam_solve_vec3", test_adam_solve_vec3, devices=devices)
153
- add_function_test(TestArray, "test_adam_solve_two_inputs", test_adam_solve_two_inputs, devices=devices)
154
-
155
- return TestArray
106
+ with wp.ScopedDevice(device):
107
+ params_start1 = np.array([[0.1, 0.2, -0.1]], dtype=float)
108
+ params_start2 = np.array([[0.2, 0.1, 0.1]], dtype=float)
109
+ score = wp.zeros(1, dtype=float, requires_grad=True)
110
+ params1 = wp.array(params_start1, dtype=wp.vec3, requires_grad=True)
111
+ params2 = wp.array(params_start2, dtype=wp.vec3, requires_grad=True)
112
+ tape = wp.Tape()
113
+ opt = warp.optim.Adam([params1, params2], lr=0.02, betas=(0.8, 0.999))
114
+
115
+ def gradient_func():
116
+ tape.reset()
117
+ score.zero_()
118
+ with tape:
119
+ wp.launch(kernel=objective_two_inputs_vec3, dim=len(params1), inputs=[params1, params2, score])
120
+ tape.backward(score)
121
+ return [tape.gradients[params1], tape.gradients[params2]]
122
+
123
+ niters = 100
124
+ opt.reset_internal_state()
125
+ for _ in range(niters):
126
+ opt.step(gradient_func())
127
+
128
+ result = params1.numpy()
129
+ tol = 1e-5
130
+ # optimum is at the origin, so the result should be close to it in all N dimensions.
131
+ for r in result:
132
+ for v in r:
133
+ test.assertLessEqual(v, tol)
134
+
135
+ result = params2.numpy()
136
+ tol = 1e-5
137
+ # optimum is at the origin, so the result should be close to it in all N dimensions.
138
+ for r in result:
139
+ for v in r:
140
+ test.assertLessEqual(v, tol)
141
+
142
+
143
+ devices = get_test_devices()
144
+
145
+
146
+ class TestAdam(unittest.TestCase):
147
+ pass
148
+
149
+
150
+ add_function_test(TestAdam, "test_adam_solve_float", test_adam_solve_float, devices=devices)
151
+ add_function_test(TestAdam, "test_adam_solve_vec3", test_adam_solve_vec3, devices=devices)
152
+ add_function_test(TestAdam, "test_adam_solve_two_inputs", test_adam_solve_two_inputs, devices=devices)
156
153
 
157
154
 
158
155
  if __name__ == "__main__":
159
- c = register(unittest.TestCase)
156
+ wp.build.clear_kernel_cache()
160
157
  unittest.main(verbosity=2)
warp/tests/test_arithmetic.py CHANGED
@@ -5,9 +5,13 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ import math
9
+ import unittest
10
+
8
11
  import numpy as np
12
+
9
13
  import warp as wp
10
- from warp.tests.test_base import *
14
+ from warp.tests.unittest_utils import *
11
15
 
12
16
  wp.init()
13
17
 
@@ -34,22 +38,21 @@ np_float_types = [np.float16, np.float32, np.float64]
34
38
  np_scalar_types = np_int_types + np_float_types
35
39
 
36
40
 
37
- def randvals(shape, dtype):
41
+ def randvals(rng, shape, dtype):
38
42
  if dtype in np_float_types:
39
- return np.random.randn(*shape).astype(dtype)
43
+ return rng.standard_normal(size=shape).astype(dtype)
40
44
  elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
41
- return np.random.randint(1, 3, size=shape, dtype=dtype)
42
- return np.random.randint(1, 5, size=shape, dtype=dtype)
45
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
46
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
43
47
 
44
48
 
45
49
  kernel_cache = dict()
46
50
 
47
51
 
48
52
  def getkernel(func, suffix=""):
49
- module = wp.get_module(func.__module__)
50
53
  key = func.__name__ + "_" + suffix
51
54
  if key not in kernel_cache:
52
- kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
55
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
53
56
  return kernel_cache[key]
54
57
 
55
58
 
@@ -77,7 +80,7 @@ def get_select_kernel2(dtype):
77
80
 
78
81
 
79
82
  def test_arrays(test, device, dtype):
80
- np.random.seed(123)
83
+ rng = np.random.default_rng(123)
81
84
 
82
85
  tol = {
83
86
  np.float16: 1.0e-3,
@@ -86,14 +89,14 @@ def test_arrays(test, device, dtype):
86
89
  }.get(dtype, 0)
87
90
 
88
91
  wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
89
- arr_np = randvals((10, 5), dtype)
92
+ arr_np = randvals(rng, (10, 5), dtype)
90
93
  arr = wp.array(arr_np, dtype=wptype, requires_grad=True, device=device)
91
94
 
92
95
  assert_np_equal(arr.numpy(), arr_np, tol=tol)
93
96
 
94
97
 
95
98
  def test_unary_ops(test, device, dtype, register_kernels=False):
96
- np.random.seed(123)
99
+ rng = np.random.default_rng(123)
97
100
 
98
101
  tol = {
99
102
  np.float16: 5.0e-3,
@@ -128,10 +131,12 @@ def test_unary_ops(test, device, dtype, register_kernels=False):
128
131
  return
129
132
 
130
133
  if dtype in np_float_types:
131
- inputs = wp.array(np.random.randn(5, 10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
134
+ inputs = wp.array(
135
+ rng.standard_normal(size=(5, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device
136
+ )
132
137
  else:
133
138
  inputs = wp.array(
134
- np.random.randint(-2, 3, size=(5, 10), dtype=dtype), dtype=wptype, requires_grad=True, device=device
139
+ rng.integers(-2, high=3, size=(5, 10), dtype=dtype), dtype=wptype, requires_grad=True, device=device
135
140
  )
136
141
  outputs = wp.zeros_like(inputs)
137
142
 
@@ -207,7 +212,7 @@ def test_unary_ops(test, device, dtype, register_kernels=False):
207
212
 
208
213
 
209
214
  def test_nonzero(test, device, dtype, register_kernels=False):
210
- np.random.seed(123)
215
+ rng = np.random.default_rng(123)
211
216
 
212
217
  tol = {
213
218
  np.float16: 5.0e-3,
@@ -231,7 +236,7 @@ def test_nonzero(test, device, dtype, register_kernels=False):
231
236
  if register_kernels:
232
237
  return
233
238
 
234
- inputs = wp.array(np.random.randint(-2, 3, size=10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
239
+ inputs = wp.array(rng.integers(-2, high=3, size=10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
235
240
  outputs = wp.zeros_like(inputs)
236
241
 
237
242
  wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
@@ -253,10 +258,10 @@ def test_nonzero(test, device, dtype, register_kernels=False):
253
258
 
254
259
 
255
260
  def test_binary_ops(test, device, dtype, register_kernels=False):
256
- np.random.seed(123)
261
+ rng = np.random.default_rng(123)
257
262
 
258
263
  tol = {
259
- np.float16: 1.0e-2,
264
+ np.float16: 5.0e-2,
260
265
  np.float32: 1.0e-6,
261
266
  np.float64: 1.0e-8,
262
267
  }.get(dtype, 0)
@@ -302,11 +307,11 @@ def test_binary_ops(test, device, dtype, register_kernels=False):
302
307
  if register_kernels:
303
308
  return
304
309
 
305
- vals1 = randvals([8, 10], dtype)
310
+ vals1 = randvals(rng, [8, 10], dtype)
306
311
  if dtype in [np_unsigned_int_types]:
307
- vals2 = vals1 + randvals([8, 10], dtype)
312
+ vals2 = vals1 + randvals(rng, [8, 10], dtype)
308
313
  else:
309
- vals2 = np.abs(randvals([8, 10], dtype))
314
+ vals2 = np.abs(randvals(rng, [8, 10], dtype))
310
315
 
311
316
  in1 = wp.array(vals1, dtype=wptype, requires_grad=True, device=device)
312
317
  in2 = wp.array(vals2, dtype=wptype, requires_grad=True, device=device)
@@ -458,7 +463,7 @@ def test_binary_ops(test, device, dtype, register_kernels=False):
458
463
 
459
464
 
460
465
  def test_special_funcs(test, device, dtype, register_kernels=False):
461
- np.random.seed(123)
466
+ rng = np.random.default_rng(123)
462
467
 
463
468
  tol = {
464
469
  np.float16: 1.0e-2,
@@ -488,6 +493,7 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
488
493
  outputs[11, i] = wptype(2) * wp.tanh(inputs[11, i])
489
494
  outputs[12, i] = wptype(2) * wp.acos(inputs[12, i])
490
495
  outputs[13, i] = wptype(2) * wp.asin(inputs[13, i])
496
+ outputs[14, i] = wptype(2) * wp.cbrt(inputs[14, i])
491
497
 
492
498
  kernel = getkernel(check_special_funcs, suffix=dtype.__name__)
493
499
  output_select_kernel = get_select_kernel2(wptype)
@@ -495,8 +501,8 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
495
501
  if register_kernels:
496
502
  return
497
503
 
498
- invals = np.random.randn(14, 10).astype(dtype)
499
- invals[[0, 1, 2, 7]] = 0.1 + np.abs(invals[[0, 1, 2, 7]])
504
+ invals = rng.normal(size=(15, 10)).astype(dtype)
505
+ invals[[0, 1, 2, 7, 14]] = 0.1 + np.abs(invals[[0, 1, 2, 7, 14]])
500
506
  invals[12] = np.clip(invals[12], -0.9, 0.9)
501
507
  invals[13] = np.clip(invals[13], -0.9, 0.9)
502
508
  inputs = wp.array(invals, dtype=wptype, requires_grad=True, device=device)
@@ -518,6 +524,7 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
518
524
  assert_np_equal(outputs.numpy()[11], 2 * np.tanh(inputs.numpy()[11]), tol=tol)
519
525
  assert_np_equal(outputs.numpy()[12], 2 * np.arccos(inputs.numpy()[12]), tol=tol)
520
526
  assert_np_equal(outputs.numpy()[13], 2 * np.arcsin(inputs.numpy()[13]), tol=tol)
527
+ assert_np_equal(outputs.numpy()[14], 2 * np.cbrt(inputs.numpy()[14]), tol=tol)
521
528
 
522
529
  out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
523
530
  if dtype in np_float_types:
@@ -694,9 +701,22 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
694
701
  assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=6 * tol)
695
702
  tape.zero()
696
703
 
704
+ # cbrt:
705
+ tape = wp.Tape()
706
+ with tape:
707
+ wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
708
+ wp.launch(output_select_kernel, dim=1, inputs=[outputs, 14, i], outputs=[out], device=device)
709
+
710
+ tape.backward(loss=out)
711
+ expected = np.zeros_like(inputs.numpy())
712
+ cbrt = np.cbrt(inputs.numpy()[14, i], dtype=np.dtype(dtype))
713
+ expected[14, i] = (2.0 / 3.0) * (1.0 / (cbrt * cbrt))
714
+ assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
715
+ tape.zero()
716
+
697
717
 
698
718
  def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
699
- np.random.seed(123)
719
+ rng = np.random.default_rng(123)
700
720
 
701
721
  tol = {
702
722
  np.float16: 1.0e-2,
@@ -722,8 +742,8 @@ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
722
742
  if register_kernels:
723
743
  return
724
744
 
725
- in1 = wp.array(np.abs(randvals([2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
726
- in2 = wp.array(randvals([2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
745
+ in1 = wp.array(np.abs(randvals(rng, [2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
746
+ in2 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
727
747
  outputs = wp.zeros_like(in1)
728
748
 
729
749
  wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
@@ -763,7 +783,7 @@ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
763
783
 
764
784
 
765
785
  def test_float_to_int(test, device, dtype, register_kernels=False):
766
- np.random.seed(123)
786
+ rng = np.random.default_rng(123)
767
787
 
768
788
  tol = {
769
789
  np.float16: 5.0e-3,
@@ -783,6 +803,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
783
803
  outputs[2, i] = wp.trunc(inputs[2, i])
784
804
  outputs[3, i] = wp.floor(inputs[3, i])
785
805
  outputs[4, i] = wp.ceil(inputs[4, i])
806
+ outputs[5, i] = wp.frac(inputs[5, i])
786
807
 
787
808
  kernel = getkernel(check_float_to_int, suffix=dtype.__name__)
788
809
  output_select_kernel = get_select_kernel2(wptype)
@@ -790,7 +811,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
790
811
  if register_kernels:
791
812
  return
792
813
 
793
- inputs = wp.array(np.random.randn(5, 10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
814
+ inputs = wp.array(rng.standard_normal(size=(6, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device)
794
815
  outputs = wp.zeros_like(inputs)
795
816
 
796
817
  wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
@@ -800,6 +821,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
800
821
  assert_np_equal(outputs.numpy()[2], np.trunc(inputs.numpy()[2]))
801
822
  assert_np_equal(outputs.numpy()[3], np.floor(inputs.numpy()[3]))
802
823
  assert_np_equal(outputs.numpy()[4], np.ceil(inputs.numpy()[4]))
824
+ assert_np_equal(outputs.numpy()[5], np.modf(inputs.numpy()[5])[0])
803
825
 
804
826
  # all the gradients should be zero as these functions are piecewise constant:
805
827
 
@@ -816,8 +838,38 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
816
838
  tape.zero()
817
839
 
818
840
 
841
+ def test_infinity(test, device, dtype, register_kernels=False):
842
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
843
+
844
+ def check_infinity(
845
+ outputs: wp.array(dtype=wptype),
846
+ ):
847
+ outputs[0] = wptype(wp.inf)
848
+ outputs[1] = wptype(-wp.inf)
849
+ outputs[2] = wptype(2.0 * wp.inf)
850
+ outputs[3] = wptype(-2.0 * wp.inf)
851
+ outputs[4] = wptype(2.0 / 0.0)
852
+ outputs[5] = wptype(-2.0 / 0.0)
853
+
854
+ kernel = getkernel(check_infinity, suffix=dtype.__name__)
855
+
856
+ if register_kernels:
857
+ return
858
+
859
+ outputs = wp.zeros(6, dtype=wptype, device=device)
860
+
861
+ wp.launch(kernel, dim=1, inputs=[], outputs=[outputs], device=device)
862
+
863
+ test.assertEqual(outputs.numpy()[0], math.inf)
864
+ test.assertEqual(outputs.numpy()[1], -math.inf)
865
+ test.assertEqual(outputs.numpy()[2], math.inf)
866
+ test.assertEqual(outputs.numpy()[3], -math.inf)
867
+ test.assertEqual(outputs.numpy()[4], math.inf)
868
+ test.assertEqual(outputs.numpy()[5], -math.inf)
869
+
870
+
819
871
  def test_interp(test, device, dtype, register_kernels=False):
820
- np.random.seed(123)
872
+ rng = np.random.default_rng(123)
821
873
 
822
874
  tol = {
823
875
  np.float16: 1.0e-2,
@@ -844,11 +896,11 @@ def test_interp(test, device, dtype, register_kernels=False):
844
896
  if register_kernels:
845
897
  return
846
898
 
847
- e0 = randvals([2, 10], dtype)
848
- e1 = e0 + randvals([2, 10], dtype) + 0.1
899
+ e0 = randvals(rng, [2, 10], dtype)
900
+ e1 = e0 + randvals(rng, [2, 10], dtype) + 0.1
849
901
  in1 = wp.array(e0, dtype=wptype, requires_grad=True, device=device)
850
902
  in2 = wp.array(e1, dtype=wptype, requires_grad=True, device=device)
851
- in3 = wp.array(randvals([2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
903
+ in3 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
852
904
 
853
905
  outputs = wp.zeros_like(in1)
854
906
 
@@ -948,7 +1000,7 @@ def test_interp(test, device, dtype, register_kernels=False):
948
1000
 
949
1001
 
950
1002
  def test_clamp(test, device, dtype, register_kernels=False):
951
- np.random.seed(123)
1003
+ rng = np.random.default_rng(123)
952
1004
 
953
1005
  tol = {
954
1006
  np.float16: 5.0e-3,
@@ -974,9 +1026,9 @@ def test_clamp(test, device, dtype, register_kernels=False):
974
1026
  if register_kernels:
975
1027
  return
976
1028
 
977
- in1 = wp.array(randvals([100], dtype), dtype=wptype, requires_grad=True, device=device)
978
- starts = randvals([100], dtype)
979
- diffs = np.abs(randvals([100], dtype))
1029
+ in1 = wp.array(randvals(rng, [100], dtype), dtype=wptype, requires_grad=True, device=device)
1030
+ starts = randvals(rng, [100], dtype)
1031
+ diffs = np.abs(randvals(rng, [100], dtype))
980
1032
  in2 = wp.array(starts, dtype=wptype, requires_grad=True, device=device)
981
1033
  in3 = wp.array(starts + diffs, dtype=wptype, requires_grad=True, device=device)
982
1034
  outputs = wp.zeros_like(in1)
@@ -1020,51 +1072,53 @@ def test_clamp(test, device, dtype, register_kernels=False):
1020
1072
  tape.zero()
1021
1073
 
1022
1074
 
1023
- def register(parent):
1024
- devices = get_test_devices()
1075
+ devices = get_test_devices()
1025
1076
 
1026
- class TestArithmetic(parent):
1027
- pass
1028
1077
 
1029
- # these unary ops only make sense for signed values:
1030
- for dtype in np_signed_int_types + np_float_types:
1031
- add_function_test_register_kernel(
1032
- TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
1033
- )
1078
+ class TestArithmetic(unittest.TestCase):
1079
+ pass
1034
1080
 
1035
- for dtype in np_float_types:
1036
- add_function_test_register_kernel(
1037
- TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
1038
- )
1039
- add_function_test_register_kernel(
1040
- TestArithmetic,
1041
- f"test_special_funcs_2arg_{dtype.__name__}",
1042
- test_special_funcs_2arg,
1043
- devices=devices,
1044
- dtype=dtype,
1045
- )
1046
- add_function_test_register_kernel(
1047
- TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
1048
- )
1049
- add_function_test_register_kernel(
1050
- TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
1051
- )
1052
1081
 
1053
- for dtype in np_scalar_types:
1054
- add_function_test_register_kernel(
1055
- TestArithmetic, f"test_clamp_{dtype.__name__}", test_clamp, devices=devices, dtype=dtype
1056
- )
1057
- add_function_test_register_kernel(
1058
- TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
1059
- )
1060
- add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
1061
- add_function_test_register_kernel(
1062
- TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
1063
- )
1082
+ # these unary ops only make sense for signed values:
1083
+ for dtype in np_signed_int_types + np_float_types:
1084
+ add_function_test_register_kernel(
1085
+ TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
1086
+ )
1064
1087
 
1065
- return TestArithmetic
1088
+ for dtype in np_float_types:
1089
+ add_function_test_register_kernel(
1090
+ TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
1091
+ )
1092
+ add_function_test_register_kernel(
1093
+ TestArithmetic,
1094
+ f"test_special_funcs_2arg_{dtype.__name__}",
1095
+ test_special_funcs_2arg,
1096
+ devices=devices,
1097
+ dtype=dtype,
1098
+ )
1099
+ add_function_test_register_kernel(
1100
+ TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
1101
+ )
1102
+ add_function_test_register_kernel(
1103
+ TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
1104
+ )
1105
+ add_function_test_register_kernel(
1106
+ TestArithmetic, f"test_infinity_{dtype.__name__}", test_infinity, devices=devices, dtype=dtype
1107
+ )
1108
+
1109
+ for dtype in np_scalar_types:
1110
+ add_function_test_register_kernel(
1111
+ TestArithmetic, f"test_clamp_{dtype.__name__}", test_clamp, devices=devices, dtype=dtype
1112
+ )
1113
+ add_function_test_register_kernel(
1114
+ TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
1115
+ )
1116
+ add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
1117
+ add_function_test_register_kernel(
1118
+ TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
1119
+ )
1066
1120
 
1067
1121
 
1068
1122
  if __name__ == "__main__":
1069
- c = register(unittest.TestCase)
1123
+ wp.build.clear_kernel_cache()
1070
1124
  unittest.main(verbosity=2, failfast=False)