warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang has been flagged as potentially problematic; see the package registry's page for this release for more details.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/tests/test_arithmetic.py
CHANGED
|
@@ -5,9 +5,13 @@
|
|
|
5
5
|
# distribution of this software and related documentation without an express
|
|
6
6
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
7
|
|
|
8
|
+
import math
|
|
9
|
+
import unittest
|
|
10
|
+
|
|
8
11
|
import numpy as np
|
|
12
|
+
|
|
9
13
|
import warp as wp
|
|
10
|
-
from warp.tests.
|
|
14
|
+
from warp.tests.unittest_utils import *
|
|
11
15
|
|
|
12
16
|
wp.init()
|
|
13
17
|
|
|
@@ -34,22 +38,21 @@ np_float_types = [np.float16, np.float32, np.float64]
|
|
|
34
38
|
np_scalar_types = np_int_types + np_float_types
|
|
35
39
|
|
|
36
40
|
|
|
37
|
-
def randvals(shape, dtype):
|
|
41
|
+
def randvals(rng, shape, dtype):
|
|
38
42
|
if dtype in np_float_types:
|
|
39
|
-
return
|
|
43
|
+
return rng.standard_normal(size=shape).astype(dtype)
|
|
40
44
|
elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
|
|
41
|
-
return
|
|
42
|
-
return
|
|
45
|
+
return rng.integers(1, high=3, size=shape, dtype=dtype)
|
|
46
|
+
return rng.integers(1, high=5, size=shape, dtype=dtype)
|
|
43
47
|
|
|
44
48
|
|
|
45
49
|
kernel_cache = dict()
|
|
46
50
|
|
|
47
51
|
|
|
48
52
|
def getkernel(func, suffix=""):
|
|
49
|
-
module = wp.get_module(func.__module__)
|
|
50
53
|
key = func.__name__ + "_" + suffix
|
|
51
54
|
if key not in kernel_cache:
|
|
52
|
-
kernel_cache[key] = wp.Kernel(func=func, key=key
|
|
55
|
+
kernel_cache[key] = wp.Kernel(func=func, key=key)
|
|
53
56
|
return kernel_cache[key]
|
|
54
57
|
|
|
55
58
|
|
|
@@ -77,7 +80,7 @@ def get_select_kernel2(dtype):
|
|
|
77
80
|
|
|
78
81
|
|
|
79
82
|
def test_arrays(test, device, dtype):
|
|
80
|
-
np.random.
|
|
83
|
+
rng = np.random.default_rng(123)
|
|
81
84
|
|
|
82
85
|
tol = {
|
|
83
86
|
np.float16: 1.0e-3,
|
|
@@ -86,14 +89,14 @@ def test_arrays(test, device, dtype):
|
|
|
86
89
|
}.get(dtype, 0)
|
|
87
90
|
|
|
88
91
|
wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
|
|
89
|
-
arr_np = randvals((10, 5), dtype)
|
|
92
|
+
arr_np = randvals(rng, (10, 5), dtype)
|
|
90
93
|
arr = wp.array(arr_np, dtype=wptype, requires_grad=True, device=device)
|
|
91
94
|
|
|
92
95
|
assert_np_equal(arr.numpy(), arr_np, tol=tol)
|
|
93
96
|
|
|
94
97
|
|
|
95
98
|
def test_unary_ops(test, device, dtype, register_kernels=False):
|
|
96
|
-
np.random.
|
|
99
|
+
rng = np.random.default_rng(123)
|
|
97
100
|
|
|
98
101
|
tol = {
|
|
99
102
|
np.float16: 5.0e-3,
|
|
@@ -128,10 +131,12 @@ def test_unary_ops(test, device, dtype, register_kernels=False):
|
|
|
128
131
|
return
|
|
129
132
|
|
|
130
133
|
if dtype in np_float_types:
|
|
131
|
-
inputs = wp.array(
|
|
134
|
+
inputs = wp.array(
|
|
135
|
+
rng.standard_normal(size=(5, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device
|
|
136
|
+
)
|
|
132
137
|
else:
|
|
133
138
|
inputs = wp.array(
|
|
134
|
-
|
|
139
|
+
rng.integers(-2, high=3, size=(5, 10), dtype=dtype), dtype=wptype, requires_grad=True, device=device
|
|
135
140
|
)
|
|
136
141
|
outputs = wp.zeros_like(inputs)
|
|
137
142
|
|
|
@@ -207,7 +212,7 @@ def test_unary_ops(test, device, dtype, register_kernels=False):
|
|
|
207
212
|
|
|
208
213
|
|
|
209
214
|
def test_nonzero(test, device, dtype, register_kernels=False):
|
|
210
|
-
np.random.
|
|
215
|
+
rng = np.random.default_rng(123)
|
|
211
216
|
|
|
212
217
|
tol = {
|
|
213
218
|
np.float16: 5.0e-3,
|
|
@@ -231,7 +236,7 @@ def test_nonzero(test, device, dtype, register_kernels=False):
|
|
|
231
236
|
if register_kernels:
|
|
232
237
|
return
|
|
233
238
|
|
|
234
|
-
inputs = wp.array(
|
|
239
|
+
inputs = wp.array(rng.integers(-2, high=3, size=10).astype(dtype), dtype=wptype, requires_grad=True, device=device)
|
|
235
240
|
outputs = wp.zeros_like(inputs)
|
|
236
241
|
|
|
237
242
|
wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
|
|
@@ -253,10 +258,10 @@ def test_nonzero(test, device, dtype, register_kernels=False):
|
|
|
253
258
|
|
|
254
259
|
|
|
255
260
|
def test_binary_ops(test, device, dtype, register_kernels=False):
|
|
256
|
-
np.random.
|
|
261
|
+
rng = np.random.default_rng(123)
|
|
257
262
|
|
|
258
263
|
tol = {
|
|
259
|
-
np.float16:
|
|
264
|
+
np.float16: 5.0e-2,
|
|
260
265
|
np.float32: 1.0e-6,
|
|
261
266
|
np.float64: 1.0e-8,
|
|
262
267
|
}.get(dtype, 0)
|
|
@@ -302,11 +307,11 @@ def test_binary_ops(test, device, dtype, register_kernels=False):
|
|
|
302
307
|
if register_kernels:
|
|
303
308
|
return
|
|
304
309
|
|
|
305
|
-
vals1 = randvals([8, 10], dtype)
|
|
310
|
+
vals1 = randvals(rng, [8, 10], dtype)
|
|
306
311
|
if dtype in [np_unsigned_int_types]:
|
|
307
|
-
vals2 = vals1 + randvals([8, 10], dtype)
|
|
312
|
+
vals2 = vals1 + randvals(rng, [8, 10], dtype)
|
|
308
313
|
else:
|
|
309
|
-
vals2 = np.abs(randvals([8, 10], dtype))
|
|
314
|
+
vals2 = np.abs(randvals(rng, [8, 10], dtype))
|
|
310
315
|
|
|
311
316
|
in1 = wp.array(vals1, dtype=wptype, requires_grad=True, device=device)
|
|
312
317
|
in2 = wp.array(vals2, dtype=wptype, requires_grad=True, device=device)
|
|
@@ -458,7 +463,7 @@ def test_binary_ops(test, device, dtype, register_kernels=False):
|
|
|
458
463
|
|
|
459
464
|
|
|
460
465
|
def test_special_funcs(test, device, dtype, register_kernels=False):
|
|
461
|
-
np.random.
|
|
466
|
+
rng = np.random.default_rng(123)
|
|
462
467
|
|
|
463
468
|
tol = {
|
|
464
469
|
np.float16: 1.0e-2,
|
|
@@ -488,6 +493,7 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
|
|
|
488
493
|
outputs[11, i] = wptype(2) * wp.tanh(inputs[11, i])
|
|
489
494
|
outputs[12, i] = wptype(2) * wp.acos(inputs[12, i])
|
|
490
495
|
outputs[13, i] = wptype(2) * wp.asin(inputs[13, i])
|
|
496
|
+
outputs[14, i] = wptype(2) * wp.cbrt(inputs[14, i])
|
|
491
497
|
|
|
492
498
|
kernel = getkernel(check_special_funcs, suffix=dtype.__name__)
|
|
493
499
|
output_select_kernel = get_select_kernel2(wptype)
|
|
@@ -495,8 +501,8 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
|
|
|
495
501
|
if register_kernels:
|
|
496
502
|
return
|
|
497
503
|
|
|
498
|
-
invals =
|
|
499
|
-
invals[[0, 1, 2, 7]] = 0.1 + np.abs(invals[[0, 1, 2, 7]])
|
|
504
|
+
invals = rng.normal(size=(15, 10)).astype(dtype)
|
|
505
|
+
invals[[0, 1, 2, 7, 14]] = 0.1 + np.abs(invals[[0, 1, 2, 7, 14]])
|
|
500
506
|
invals[12] = np.clip(invals[12], -0.9, 0.9)
|
|
501
507
|
invals[13] = np.clip(invals[13], -0.9, 0.9)
|
|
502
508
|
inputs = wp.array(invals, dtype=wptype, requires_grad=True, device=device)
|
|
@@ -518,6 +524,7 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
|
|
|
518
524
|
assert_np_equal(outputs.numpy()[11], 2 * np.tanh(inputs.numpy()[11]), tol=tol)
|
|
519
525
|
assert_np_equal(outputs.numpy()[12], 2 * np.arccos(inputs.numpy()[12]), tol=tol)
|
|
520
526
|
assert_np_equal(outputs.numpy()[13], 2 * np.arcsin(inputs.numpy()[13]), tol=tol)
|
|
527
|
+
assert_np_equal(outputs.numpy()[14], 2 * np.cbrt(inputs.numpy()[14]), tol=tol)
|
|
521
528
|
|
|
522
529
|
out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
|
|
523
530
|
if dtype in np_float_types:
|
|
@@ -694,9 +701,22 @@ def test_special_funcs(test, device, dtype, register_kernels=False):
|
|
|
694
701
|
assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=6 * tol)
|
|
695
702
|
tape.zero()
|
|
696
703
|
|
|
704
|
+
# cbrt:
|
|
705
|
+
tape = wp.Tape()
|
|
706
|
+
with tape:
|
|
707
|
+
wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
|
|
708
|
+
wp.launch(output_select_kernel, dim=1, inputs=[outputs, 14, i], outputs=[out], device=device)
|
|
709
|
+
|
|
710
|
+
tape.backward(loss=out)
|
|
711
|
+
expected = np.zeros_like(inputs.numpy())
|
|
712
|
+
cbrt = np.cbrt(inputs.numpy()[14, i], dtype=np.dtype(dtype))
|
|
713
|
+
expected[14, i] = (2.0 / 3.0) * (1.0 / (cbrt * cbrt))
|
|
714
|
+
assert_np_equal(tape.gradients[inputs].numpy(), expected, tol=tol)
|
|
715
|
+
tape.zero()
|
|
716
|
+
|
|
697
717
|
|
|
698
718
|
def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
|
|
699
|
-
np.random.
|
|
719
|
+
rng = np.random.default_rng(123)
|
|
700
720
|
|
|
701
721
|
tol = {
|
|
702
722
|
np.float16: 1.0e-2,
|
|
@@ -722,8 +742,8 @@ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
|
|
|
722
742
|
if register_kernels:
|
|
723
743
|
return
|
|
724
744
|
|
|
725
|
-
in1 = wp.array(np.abs(randvals([2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
|
|
726
|
-
in2 = wp.array(randvals([2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
|
|
745
|
+
in1 = wp.array(np.abs(randvals(rng, [2, 10], dtype)), dtype=wptype, requires_grad=True, device=device)
|
|
746
|
+
in2 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
|
|
727
747
|
outputs = wp.zeros_like(in1)
|
|
728
748
|
|
|
729
749
|
wp.launch(kernel, dim=1, inputs=[in1, in2], outputs=[outputs], device=device)
|
|
@@ -763,7 +783,7 @@ def test_special_funcs_2arg(test, device, dtype, register_kernels=False):
|
|
|
763
783
|
|
|
764
784
|
|
|
765
785
|
def test_float_to_int(test, device, dtype, register_kernels=False):
|
|
766
|
-
np.random.
|
|
786
|
+
rng = np.random.default_rng(123)
|
|
767
787
|
|
|
768
788
|
tol = {
|
|
769
789
|
np.float16: 5.0e-3,
|
|
@@ -783,6 +803,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
|
|
|
783
803
|
outputs[2, i] = wp.trunc(inputs[2, i])
|
|
784
804
|
outputs[3, i] = wp.floor(inputs[3, i])
|
|
785
805
|
outputs[4, i] = wp.ceil(inputs[4, i])
|
|
806
|
+
outputs[5, i] = wp.frac(inputs[5, i])
|
|
786
807
|
|
|
787
808
|
kernel = getkernel(check_float_to_int, suffix=dtype.__name__)
|
|
788
809
|
output_select_kernel = get_select_kernel2(wptype)
|
|
@@ -790,7 +811,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
|
|
|
790
811
|
if register_kernels:
|
|
791
812
|
return
|
|
792
813
|
|
|
793
|
-
inputs = wp.array(
|
|
814
|
+
inputs = wp.array(rng.standard_normal(size=(6, 10)).astype(dtype), dtype=wptype, requires_grad=True, device=device)
|
|
794
815
|
outputs = wp.zeros_like(inputs)
|
|
795
816
|
|
|
796
817
|
wp.launch(kernel, dim=1, inputs=[inputs], outputs=[outputs], device=device)
|
|
@@ -800,6 +821,7 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
|
|
|
800
821
|
assert_np_equal(outputs.numpy()[2], np.trunc(inputs.numpy()[2]))
|
|
801
822
|
assert_np_equal(outputs.numpy()[3], np.floor(inputs.numpy()[3]))
|
|
802
823
|
assert_np_equal(outputs.numpy()[4], np.ceil(inputs.numpy()[4]))
|
|
824
|
+
assert_np_equal(outputs.numpy()[5], np.modf(inputs.numpy()[5])[0])
|
|
803
825
|
|
|
804
826
|
# all the gradients should be zero as these functions are piecewise constant:
|
|
805
827
|
|
|
@@ -816,8 +838,38 @@ def test_float_to_int(test, device, dtype, register_kernels=False):
|
|
|
816
838
|
tape.zero()
|
|
817
839
|
|
|
818
840
|
|
|
841
|
+
def test_infinity(test, device, dtype, register_kernels=False):
|
|
842
|
+
wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
|
|
843
|
+
|
|
844
|
+
def check_infinity(
|
|
845
|
+
outputs: wp.array(dtype=wptype),
|
|
846
|
+
):
|
|
847
|
+
outputs[0] = wptype(wp.inf)
|
|
848
|
+
outputs[1] = wptype(-wp.inf)
|
|
849
|
+
outputs[2] = wptype(2.0 * wp.inf)
|
|
850
|
+
outputs[3] = wptype(-2.0 * wp.inf)
|
|
851
|
+
outputs[4] = wptype(2.0 / 0.0)
|
|
852
|
+
outputs[5] = wptype(-2.0 / 0.0)
|
|
853
|
+
|
|
854
|
+
kernel = getkernel(check_infinity, suffix=dtype.__name__)
|
|
855
|
+
|
|
856
|
+
if register_kernels:
|
|
857
|
+
return
|
|
858
|
+
|
|
859
|
+
outputs = wp.zeros(6, dtype=wptype, device=device)
|
|
860
|
+
|
|
861
|
+
wp.launch(kernel, dim=1, inputs=[], outputs=[outputs], device=device)
|
|
862
|
+
|
|
863
|
+
test.assertEqual(outputs.numpy()[0], math.inf)
|
|
864
|
+
test.assertEqual(outputs.numpy()[1], -math.inf)
|
|
865
|
+
test.assertEqual(outputs.numpy()[2], math.inf)
|
|
866
|
+
test.assertEqual(outputs.numpy()[3], -math.inf)
|
|
867
|
+
test.assertEqual(outputs.numpy()[4], math.inf)
|
|
868
|
+
test.assertEqual(outputs.numpy()[5], -math.inf)
|
|
869
|
+
|
|
870
|
+
|
|
819
871
|
def test_interp(test, device, dtype, register_kernels=False):
|
|
820
|
-
np.random.
|
|
872
|
+
rng = np.random.default_rng(123)
|
|
821
873
|
|
|
822
874
|
tol = {
|
|
823
875
|
np.float16: 1.0e-2,
|
|
@@ -844,11 +896,11 @@ def test_interp(test, device, dtype, register_kernels=False):
|
|
|
844
896
|
if register_kernels:
|
|
845
897
|
return
|
|
846
898
|
|
|
847
|
-
e0 = randvals([2, 10], dtype)
|
|
848
|
-
e1 = e0 + randvals([2, 10], dtype) + 0.1
|
|
899
|
+
e0 = randvals(rng, [2, 10], dtype)
|
|
900
|
+
e1 = e0 + randvals(rng, [2, 10], dtype) + 0.1
|
|
849
901
|
in1 = wp.array(e0, dtype=wptype, requires_grad=True, device=device)
|
|
850
902
|
in2 = wp.array(e1, dtype=wptype, requires_grad=True, device=device)
|
|
851
|
-
in3 = wp.array(randvals([2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
|
|
903
|
+
in3 = wp.array(randvals(rng, [2, 10], dtype), dtype=wptype, requires_grad=True, device=device)
|
|
852
904
|
|
|
853
905
|
outputs = wp.zeros_like(in1)
|
|
854
906
|
|
|
@@ -948,7 +1000,7 @@ def test_interp(test, device, dtype, register_kernels=False):
|
|
|
948
1000
|
|
|
949
1001
|
|
|
950
1002
|
def test_clamp(test, device, dtype, register_kernels=False):
|
|
951
|
-
np.random.
|
|
1003
|
+
rng = np.random.default_rng(123)
|
|
952
1004
|
|
|
953
1005
|
tol = {
|
|
954
1006
|
np.float16: 5.0e-3,
|
|
@@ -974,9 +1026,9 @@ def test_clamp(test, device, dtype, register_kernels=False):
|
|
|
974
1026
|
if register_kernels:
|
|
975
1027
|
return
|
|
976
1028
|
|
|
977
|
-
in1 = wp.array(randvals([100], dtype), dtype=wptype, requires_grad=True, device=device)
|
|
978
|
-
starts = randvals([100], dtype)
|
|
979
|
-
diffs = np.abs(randvals([100], dtype))
|
|
1029
|
+
in1 = wp.array(randvals(rng, [100], dtype), dtype=wptype, requires_grad=True, device=device)
|
|
1030
|
+
starts = randvals(rng, [100], dtype)
|
|
1031
|
+
diffs = np.abs(randvals(rng, [100], dtype))
|
|
980
1032
|
in2 = wp.array(starts, dtype=wptype, requires_grad=True, device=device)
|
|
981
1033
|
in3 = wp.array(starts + diffs, dtype=wptype, requires_grad=True, device=device)
|
|
982
1034
|
outputs = wp.zeros_like(in1)
|
|
@@ -1020,51 +1072,53 @@ def test_clamp(test, device, dtype, register_kernels=False):
|
|
|
1020
1072
|
tape.zero()
|
|
1021
1073
|
|
|
1022
1074
|
|
|
1023
|
-
|
|
1024
|
-
devices = get_test_devices()
|
|
1075
|
+
devices = get_test_devices()
|
|
1025
1076
|
|
|
1026
|
-
class TestArithmetic(parent):
|
|
1027
|
-
pass
|
|
1028
1077
|
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
add_function_test_register_kernel(
|
|
1032
|
-
TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
|
|
1033
|
-
)
|
|
1078
|
+
class TestArithmetic(unittest.TestCase):
|
|
1079
|
+
pass
|
|
1034
1080
|
|
|
1035
|
-
for dtype in np_float_types:
|
|
1036
|
-
add_function_test_register_kernel(
|
|
1037
|
-
TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
|
|
1038
|
-
)
|
|
1039
|
-
add_function_test_register_kernel(
|
|
1040
|
-
TestArithmetic,
|
|
1041
|
-
f"test_special_funcs_2arg_{dtype.__name__}",
|
|
1042
|
-
test_special_funcs_2arg,
|
|
1043
|
-
devices=devices,
|
|
1044
|
-
dtype=dtype,
|
|
1045
|
-
)
|
|
1046
|
-
add_function_test_register_kernel(
|
|
1047
|
-
TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
|
|
1048
|
-
)
|
|
1049
|
-
add_function_test_register_kernel(
|
|
1050
|
-
TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
|
|
1051
|
-
)
|
|
1052
1081
|
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
|
|
1059
|
-
)
|
|
1060
|
-
add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
|
|
1061
|
-
add_function_test_register_kernel(
|
|
1062
|
-
TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
|
|
1063
|
-
)
|
|
1082
|
+
# these unary ops only make sense for signed values:
|
|
1083
|
+
for dtype in np_signed_int_types + np_float_types:
|
|
1084
|
+
add_function_test_register_kernel(
|
|
1085
|
+
TestArithmetic, f"test_unary_ops_{dtype.__name__}", test_unary_ops, devices=devices, dtype=dtype
|
|
1086
|
+
)
|
|
1064
1087
|
|
|
1065
|
-
|
|
1088
|
+
for dtype in np_float_types:
|
|
1089
|
+
add_function_test_register_kernel(
|
|
1090
|
+
TestArithmetic, f"test_special_funcs_{dtype.__name__}", test_special_funcs, devices=devices, dtype=dtype
|
|
1091
|
+
)
|
|
1092
|
+
add_function_test_register_kernel(
|
|
1093
|
+
TestArithmetic,
|
|
1094
|
+
f"test_special_funcs_2arg_{dtype.__name__}",
|
|
1095
|
+
test_special_funcs_2arg,
|
|
1096
|
+
devices=devices,
|
|
1097
|
+
dtype=dtype,
|
|
1098
|
+
)
|
|
1099
|
+
add_function_test_register_kernel(
|
|
1100
|
+
TestArithmetic, f"test_interp_{dtype.__name__}", test_interp, devices=devices, dtype=dtype
|
|
1101
|
+
)
|
|
1102
|
+
add_function_test_register_kernel(
|
|
1103
|
+
TestArithmetic, f"test_float_to_int_{dtype.__name__}", test_float_to_int, devices=devices, dtype=dtype
|
|
1104
|
+
)
|
|
1105
|
+
add_function_test_register_kernel(
|
|
1106
|
+
TestArithmetic, f"test_infinity_{dtype.__name__}", test_infinity, devices=devices, dtype=dtype
|
|
1107
|
+
)
|
|
1108
|
+
|
|
1109
|
+
for dtype in np_scalar_types:
|
|
1110
|
+
add_function_test_register_kernel(
|
|
1111
|
+
TestArithmetic, f"test_clamp_{dtype.__name__}", test_clamp, devices=devices, dtype=dtype
|
|
1112
|
+
)
|
|
1113
|
+
add_function_test_register_kernel(
|
|
1114
|
+
TestArithmetic, f"test_nonzero_{dtype.__name__}", test_nonzero, devices=devices, dtype=dtype
|
|
1115
|
+
)
|
|
1116
|
+
add_function_test(TestArithmetic, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
|
|
1117
|
+
add_function_test_register_kernel(
|
|
1118
|
+
TestArithmetic, f"test_binary_ops_{dtype.__name__}", test_binary_ops, devices=devices, dtype=dtype
|
|
1119
|
+
)
|
|
1066
1120
|
|
|
1067
1121
|
|
|
1068
1122
|
if __name__ == "__main__":
|
|
1069
|
-
|
|
1123
|
+
wp.build.clear_kernel_cache()
|
|
1070
1124
|
unittest.main(verbosity=2, failfast=False)
|