warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/builtin.h
CHANGED
|
@@ -46,7 +46,6 @@ __device__ void __debugbreak() {}
|
|
|
46
46
|
namespace wp
|
|
47
47
|
{
|
|
48
48
|
|
|
49
|
-
|
|
50
49
|
// numeric types (used from generated kernels)
|
|
51
50
|
typedef float float32;
|
|
52
51
|
typedef double float64;
|
|
@@ -141,7 +140,7 @@ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
|
|
|
141
140
|
|
|
142
141
|
typedef half float16;
|
|
143
142
|
|
|
144
|
-
#if __CUDA_ARCH__
|
|
143
|
+
#if defined(__CUDA_ARCH__)
|
|
145
144
|
|
|
146
145
|
CUDA_CALLABLE inline half float_to_half(float x)
|
|
147
146
|
{
|
|
@@ -157,95 +156,38 @@ CUDA_CALLABLE inline float half_to_float(half x)
|
|
|
157
156
|
return val;
|
|
158
157
|
}
|
|
159
158
|
|
|
160
|
-
#
|
|
159
|
+
#elif defined(__clang__)
|
|
161
160
|
|
|
162
|
-
//
|
|
161
|
+
// _Float16 is Clang's native half-precision floating-point type
|
|
163
162
|
inline half float_to_half(float x)
|
|
164
163
|
{
|
|
165
|
-
union fp32
|
|
166
|
-
{
|
|
167
|
-
uint32 u;
|
|
168
|
-
float f;
|
|
169
164
|
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
unsigned int mantissa : 23;
|
|
173
|
-
unsigned int exponent : 8;
|
|
174
|
-
unsigned int sign : 1;
|
|
175
|
-
};
|
|
176
|
-
};
|
|
177
|
-
|
|
178
|
-
fp32 f;
|
|
179
|
-
f.f = x;
|
|
180
|
-
|
|
181
|
-
fp32 f32infty = { 255 << 23 };
|
|
182
|
-
fp32 f16infty = { 31 << 23 };
|
|
183
|
-
fp32 magic = { 15 << 23 };
|
|
184
|
-
uint32 sign_mask = 0x80000000u;
|
|
185
|
-
uint32 round_mask = ~0xfffu;
|
|
186
|
-
half o;
|
|
187
|
-
|
|
188
|
-
uint32 sign = f.u & sign_mask;
|
|
189
|
-
f.u ^= sign;
|
|
190
|
-
|
|
191
|
-
// NOTE all the integer compares in this function can be safely
|
|
192
|
-
// compiled into signed compares since all operands are below
|
|
193
|
-
// 0x80000000. Important if you want fast straight SSE2 code
|
|
194
|
-
// (since there's no unsigned PCMPGTD).
|
|
195
|
-
|
|
196
|
-
if (f.u >= f32infty.u) // Inf or NaN (all exponent bits set)
|
|
197
|
-
o.u = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
|
|
198
|
-
else // (De)normalized number or zero
|
|
199
|
-
{
|
|
200
|
-
f.u &= round_mask;
|
|
201
|
-
f.f *= magic.f;
|
|
202
|
-
f.u -= round_mask;
|
|
203
|
-
if (f.u > f16infty.u) f.u = f16infty.u; // Clamp to signed infinity if overflowed
|
|
204
|
-
|
|
205
|
-
o.u = f.u >> 13; // Take the bits!
|
|
206
|
-
}
|
|
207
|
-
|
|
208
|
-
o.u |= sign >> 16;
|
|
209
|
-
return o;
|
|
165
|
+
_Float16 f16 = static_cast<_Float16>(x);
|
|
166
|
+
return *reinterpret_cast<half*>(&f16);
|
|
210
167
|
}
|
|
211
168
|
|
|
212
|
-
|
|
213
169
|
inline float half_to_float(half h)
|
|
214
170
|
{
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
float f;
|
|
171
|
+
_Float16 f16 = *reinterpret_cast<_Float16*>(&h);
|
|
172
|
+
return static_cast<float>(f16);
|
|
173
|
+
}
|
|
219
174
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
unsigned int sign : 1;
|
|
225
|
-
};
|
|
226
|
-
};
|
|
227
|
-
|
|
228
|
-
static const fp32 magic = { 113 << 23 };
|
|
229
|
-
static const uint32 shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
|
230
|
-
fp32 o;
|
|
231
|
-
|
|
232
|
-
o.u = (h.u & 0x7fff) << 13; // exponent/mantissa bits
|
|
233
|
-
uint32 exp = shifted_exp & o.u; // just the exponent
|
|
234
|
-
o.u += (127 - 15) << 23; // exponent adjust
|
|
235
|
-
|
|
236
|
-
// handle exponent special cases
|
|
237
|
-
if (exp == shifted_exp) // Inf/NaN?
|
|
238
|
-
o.u += (128 - 16) << 23; // extra exp adjust
|
|
239
|
-
else if (exp == 0) // Zero/Denormal?
|
|
240
|
-
{
|
|
241
|
-
o.u += 1 << 23; // extra exp adjust
|
|
242
|
-
o.f -= magic.f; // renormalize
|
|
243
|
-
}
|
|
175
|
+
#else // Native C++ for Warp builtins outside of kernels
|
|
176
|
+
|
|
177
|
+
extern "C" WP_API uint16_t float_to_half_bits(float x);
|
|
178
|
+
extern "C" WP_API float half_bits_to_float(uint16_t u);
|
|
244
179
|
|
|
245
|
-
|
|
246
|
-
|
|
180
|
+
inline half float_to_half(float x)
|
|
181
|
+
{
|
|
182
|
+
half h;
|
|
183
|
+
h.u = float_to_half_bits(x);
|
|
184
|
+
return h;
|
|
247
185
|
}
|
|
248
186
|
|
|
187
|
+
inline float half_to_float(half h)
|
|
188
|
+
{
|
|
189
|
+
return half_bits_to_float(h.u);
|
|
190
|
+
}
|
|
249
191
|
|
|
250
192
|
#endif
|
|
251
193
|
|
|
@@ -353,7 +295,7 @@ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
|
|
|
353
295
|
inline CUDA_CALLABLE T invert(T x) { return ~x; } \
|
|
354
296
|
inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
|
|
355
297
|
inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
|
|
356
|
-
inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
|
|
298
|
+
inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
|
|
357
299
|
inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
|
|
358
300
|
inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
|
|
359
301
|
inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
|
|
@@ -491,11 +433,6 @@ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b,
|
|
|
491
433
|
else\
|
|
492
434
|
adj_x += adj_ret;\
|
|
493
435
|
}\
|
|
494
|
-
inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
|
|
495
|
-
inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
|
|
496
|
-
inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
|
|
497
|
-
inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
|
|
498
|
-
inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
|
|
499
436
|
inline CUDA_CALLABLE T div(T a, T b)\
|
|
500
437
|
{\
|
|
501
438
|
DO_IF_FPCHECK(\
|
|
@@ -506,10 +443,10 @@ inline CUDA_CALLABLE T div(T a, T b)\
|
|
|
506
443
|
})\
|
|
507
444
|
return a/b;\
|
|
508
445
|
}\
|
|
509
|
-
inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
|
|
446
|
+
inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
|
|
510
447
|
{\
|
|
511
448
|
adj_a += adj_ret/b;\
|
|
512
|
-
adj_b -= adj_ret*(
|
|
449
|
+
adj_b -= adj_ret*(ret)/b;\
|
|
513
450
|
DO_IF_FPCHECK(\
|
|
514
451
|
if (!isfinite(adj_a) || !isfinite(adj_b))\
|
|
515
452
|
{\
|
|
@@ -788,16 +725,16 @@ inline CUDA_CALLABLE double floordiv(double a, double b)
|
|
|
788
725
|
inline CUDA_CALLABLE float leaky_min(float a, float b, float r) { return min(a, b); }
|
|
789
726
|
inline CUDA_CALLABLE float leaky_max(float a, float b, float r) { return max(a, b); }
|
|
790
727
|
|
|
791
|
-
inline CUDA_CALLABLE half abs(half x) { return ::
|
|
792
|
-
inline CUDA_CALLABLE float abs(float x) { return ::
|
|
728
|
+
inline CUDA_CALLABLE half abs(half x) { return ::fabsf(float(x)); }
|
|
729
|
+
inline CUDA_CALLABLE float abs(float x) { return ::fabsf(x); }
|
|
793
730
|
inline CUDA_CALLABLE double abs(double x) { return ::fabs(x); }
|
|
794
731
|
|
|
795
|
-
inline CUDA_CALLABLE float acos(float x){ return ::
|
|
796
|
-
inline CUDA_CALLABLE float asin(float x){ return ::
|
|
797
|
-
inline CUDA_CALLABLE float atan(float x) { return ::
|
|
798
|
-
inline CUDA_CALLABLE float atan2(float y, float x) { return ::
|
|
799
|
-
inline CUDA_CALLABLE float sin(float x) { return ::
|
|
800
|
-
inline CUDA_CALLABLE float cos(float x) { return ::
|
|
732
|
+
inline CUDA_CALLABLE float acos(float x){ return ::acosf(min(max(x, -1.0f), 1.0f)); }
|
|
733
|
+
inline CUDA_CALLABLE float asin(float x){ return ::asinf(min(max(x, -1.0f), 1.0f)); }
|
|
734
|
+
inline CUDA_CALLABLE float atan(float x) { return ::atanf(x); }
|
|
735
|
+
inline CUDA_CALLABLE float atan2(float y, float x) { return ::atan2f(y, x); }
|
|
736
|
+
inline CUDA_CALLABLE float sin(float x) { return ::sinf(x); }
|
|
737
|
+
inline CUDA_CALLABLE float cos(float x) { return ::cosf(x); }
|
|
801
738
|
|
|
802
739
|
inline CUDA_CALLABLE double acos(double x){ return ::acos(min(max(x, -1.0), 1.0)); }
|
|
803
740
|
inline CUDA_CALLABLE double asin(double x){ return ::asin(min(max(x, -1.0), 1.0)); }
|
|
@@ -806,12 +743,12 @@ inline CUDA_CALLABLE double atan2(double y, double x) { return ::atan2(y, x); }
|
|
|
806
743
|
inline CUDA_CALLABLE double sin(double x) { return ::sin(x); }
|
|
807
744
|
inline CUDA_CALLABLE double cos(double x) { return ::cos(x); }
|
|
808
745
|
|
|
809
|
-
inline CUDA_CALLABLE half acos(half x){ return ::
|
|
810
|
-
inline CUDA_CALLABLE half asin(half x){ return ::
|
|
811
|
-
inline CUDA_CALLABLE half atan(half x) { return ::
|
|
812
|
-
inline CUDA_CALLABLE half atan2(half y, half x) { return ::
|
|
813
|
-
inline CUDA_CALLABLE half sin(half x) { return ::
|
|
814
|
-
inline CUDA_CALLABLE half cos(half x) { return ::
|
|
746
|
+
inline CUDA_CALLABLE half acos(half x){ return ::acosf(min(max(float(x), -1.0f), 1.0f)); }
|
|
747
|
+
inline CUDA_CALLABLE half asin(half x){ return ::asinf(min(max(float(x), -1.0f), 1.0f)); }
|
|
748
|
+
inline CUDA_CALLABLE half atan(half x) { return ::atanf(float(x)); }
|
|
749
|
+
inline CUDA_CALLABLE half atan2(half y, half x) { return ::atan2f(float(y), float(x)); }
|
|
750
|
+
inline CUDA_CALLABLE half sin(half x) { return ::sinf(float(x)); }
|
|
751
|
+
inline CUDA_CALLABLE half cos(half x) { return ::cosf(float(x)); }
|
|
815
752
|
|
|
816
753
|
|
|
817
754
|
inline CUDA_CALLABLE float sqrt(float x)
|
|
@@ -823,7 +760,7 @@ inline CUDA_CALLABLE float sqrt(float x)
|
|
|
823
760
|
assert(0);
|
|
824
761
|
}
|
|
825
762
|
#endif
|
|
826
|
-
return ::
|
|
763
|
+
return ::sqrtf(x);
|
|
827
764
|
}
|
|
828
765
|
inline CUDA_CALLABLE double sqrt(double x)
|
|
829
766
|
{
|
|
@@ -845,10 +782,14 @@ inline CUDA_CALLABLE half sqrt(half x)
|
|
|
845
782
|
assert(0);
|
|
846
783
|
}
|
|
847
784
|
#endif
|
|
848
|
-
return ::
|
|
785
|
+
return ::sqrtf(float(x));
|
|
849
786
|
}
|
|
850
787
|
|
|
851
|
-
inline CUDA_CALLABLE float
|
|
788
|
+
inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
|
|
789
|
+
inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
|
|
790
|
+
inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
|
|
791
|
+
|
|
792
|
+
inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
|
|
852
793
|
inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
|
|
853
794
|
inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
|
|
854
795
|
inline CUDA_CALLABLE float tanh(float x) { return ::tanhf(x);}
|
|
@@ -862,7 +803,7 @@ inline CUDA_CALLABLE double tanh(double x) { return ::tanh(x);}
|
|
|
862
803
|
inline CUDA_CALLABLE double degrees(double x) { return x * RAD_TO_DEG;}
|
|
863
804
|
inline CUDA_CALLABLE double radians(double x) { return x * DEG_TO_RAD;}
|
|
864
805
|
|
|
865
|
-
inline CUDA_CALLABLE half tan(half x) { return ::
|
|
806
|
+
inline CUDA_CALLABLE half tan(half x) { return ::tanf(float(x)); }
|
|
866
807
|
inline CUDA_CALLABLE half sinh(half x) { return ::sinhf(float(x));}
|
|
867
808
|
inline CUDA_CALLABLE half cosh(half x) { return ::coshf(float(x));}
|
|
868
809
|
inline CUDA_CALLABLE half tanh(half x) { return ::tanhf(float(x));}
|
|
@@ -874,6 +815,21 @@ inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
|
|
|
874
815
|
inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
|
|
875
816
|
inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
|
|
876
817
|
inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
|
|
818
|
+
inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
|
|
819
|
+
|
|
820
|
+
inline CUDA_CALLABLE double round(double x) { return ::round(x); }
|
|
821
|
+
inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
|
|
822
|
+
inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
|
|
823
|
+
inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
|
|
824
|
+
inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
|
|
825
|
+
inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
|
|
826
|
+
|
|
827
|
+
inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
|
|
828
|
+
inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
|
|
829
|
+
inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
|
|
830
|
+
inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
|
|
831
|
+
inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
|
|
832
|
+
inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
|
|
877
833
|
|
|
878
834
|
#define DECLARE_ADJOINTS(T)\
|
|
879
835
|
inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
|
|
@@ -903,11 +859,11 @@ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
|
|
|
903
859
|
assert(0);\
|
|
904
860
|
})\
|
|
905
861
|
}\
|
|
906
|
-
inline CUDA_CALLABLE void adj_exp(T a, T& adj_a, T adj_ret) { adj_a +=
|
|
907
|
-
inline CUDA_CALLABLE void adj_pow(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
|
|
862
|
+
inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
|
|
863
|
+
inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
|
|
908
864
|
{ \
|
|
909
865
|
adj_a += b*pow(a, b-T(1))*adj_ret;\
|
|
910
|
-
adj_b += log(a)*
|
|
866
|
+
adj_b += log(a)*ret*adj_ret;\
|
|
911
867
|
DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
|
|
912
868
|
{\
|
|
913
869
|
printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
|
|
@@ -1006,20 +962,28 @@ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
|
|
|
1006
962
|
{\
|
|
1007
963
|
adj_x += sinh(x)*adj_ret;\
|
|
1008
964
|
}\
|
|
1009
|
-
inline CUDA_CALLABLE void adj_tanh(T x, T& adj_x, T adj_ret)\
|
|
965
|
+
inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
|
|
1010
966
|
{\
|
|
1011
|
-
|
|
1012
|
-
adj_x += (T(1) - tanh_x*tanh_x)*adj_ret;\
|
|
967
|
+
adj_x += (T(1) - ret*ret)*adj_ret;\
|
|
1013
968
|
}\
|
|
1014
|
-
inline CUDA_CALLABLE void adj_sqrt(T x, T& adj_x, T adj_ret)\
|
|
969
|
+
inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
|
|
1015
970
|
{\
|
|
1016
|
-
adj_x += T(0.5)*(T(1)/
|
|
971
|
+
adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
|
|
1017
972
|
DO_IF_FPCHECK(if (!isfinite(adj_x))\
|
|
1018
973
|
{\
|
|
1019
974
|
printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
|
|
1020
975
|
assert(0);\
|
|
1021
976
|
})\
|
|
1022
977
|
}\
|
|
978
|
+
inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
|
|
979
|
+
{\
|
|
980
|
+
adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
|
|
981
|
+
DO_IF_FPCHECK(if (!isfinite(adj_x))\
|
|
982
|
+
{\
|
|
983
|
+
printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
|
|
984
|
+
assert(0);\
|
|
985
|
+
})\
|
|
986
|
+
}\
|
|
1023
987
|
inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
|
|
1024
988
|
{\
|
|
1025
989
|
adj_x += RAD_TO_DEG * adj_ret;\
|
|
@@ -1027,7 +991,13 @@ inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
|
|
|
1027
991
|
inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
|
|
1028
992
|
{\
|
|
1029
993
|
adj_x += DEG_TO_RAD * adj_ret;\
|
|
1030
|
-
}
|
|
994
|
+
}\
|
|
995
|
+
inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
|
|
996
|
+
inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
|
|
997
|
+
inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
|
|
998
|
+
inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
|
|
999
|
+
inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
|
|
1000
|
+
inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
|
|
1031
1001
|
|
|
1032
1002
|
DECLARE_ADJOINTS(float16)
|
|
1033
1003
|
DECLARE_ADJOINTS(float32)
|
|
@@ -1051,17 +1021,31 @@ CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& a
|
|
|
1051
1021
|
}
|
|
1052
1022
|
|
|
1053
1023
|
template <typename T>
|
|
1054
|
-
CUDA_CALLABLE inline
|
|
1024
|
+
CUDA_CALLABLE inline T copy(const T& src)
|
|
1025
|
+
{
|
|
1026
|
+
return src;
|
|
1027
|
+
}
|
|
1028
|
+
|
|
1029
|
+
template <typename T>
|
|
1030
|
+
CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
|
|
1031
|
+
{
|
|
1032
|
+
adj_src = adj_dest;
|
|
1033
|
+
adj_dest = T{};
|
|
1034
|
+
}
|
|
1035
|
+
|
|
1036
|
+
template <typename T>
|
|
1037
|
+
CUDA_CALLABLE inline void assign(T& dest, const T& src)
|
|
1055
1038
|
{
|
|
1056
1039
|
dest = src;
|
|
1057
1040
|
}
|
|
1058
1041
|
|
|
1059
1042
|
template <typename T>
|
|
1060
|
-
CUDA_CALLABLE inline void
|
|
1043
|
+
CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
|
|
1061
1044
|
{
|
|
1062
|
-
//
|
|
1045
|
+
// this is generally a non-differentiable operation since it violates SSA,
|
|
1046
|
+
// except in read-modify-write statements which are reversible through backpropagation
|
|
1063
1047
|
adj_src = adj_dest;
|
|
1064
|
-
adj_dest = T
|
|
1048
|
+
adj_dest = T{};
|
|
1065
1049
|
}
|
|
1066
1050
|
|
|
1067
1051
|
|
|
@@ -1106,34 +1090,8 @@ struct launch_bounds_t
|
|
|
1106
1090
|
size_t size; // total number of threads
|
|
1107
1091
|
};
|
|
1108
1092
|
|
|
1109
|
-
#
|
|
1110
|
-
|
|
1111
|
-
// store launch bounds in shared memory so
|
|
1112
|
-
// we can access them from any user func
|
|
1113
|
-
// this is to avoid having to explicitly
|
|
1114
|
-
// set another piece of __constant__ memory
|
|
1115
|
-
// from the host
|
|
1116
|
-
__shared__ launch_bounds_t s_launchBounds;
|
|
1117
|
-
|
|
1118
|
-
__device__ inline void set_launch_bounds(const launch_bounds_t& b)
|
|
1119
|
-
{
|
|
1120
|
-
if (threadIdx.x == 0)
|
|
1121
|
-
s_launchBounds = b;
|
|
1122
|
-
|
|
1123
|
-
__syncthreads();
|
|
1124
|
-
}
|
|
1125
|
-
|
|
1126
|
-
#else
|
|
1127
|
-
|
|
1128
|
-
// for single-threaded CPU we store launch
|
|
1129
|
-
// bounds in static memory to share globally
|
|
1130
|
-
static launch_bounds_t s_launchBounds;
|
|
1093
|
+
#ifndef __CUDACC__
|
|
1131
1094
|
static size_t s_threadIdx;
|
|
1132
|
-
|
|
1133
|
-
inline void set_launch_bounds(const launch_bounds_t& b)
|
|
1134
|
-
{
|
|
1135
|
-
s_launchBounds = b;
|
|
1136
|
-
}
|
|
1137
1095
|
#endif
|
|
1138
1096
|
|
|
1139
1097
|
inline CUDA_CALLABLE size_t grid_index()
|
|
@@ -1147,10 +1105,8 @@ inline CUDA_CALLABLE size_t grid_index()
|
|
|
1147
1105
|
#endif
|
|
1148
1106
|
}
|
|
1149
1107
|
|
|
1150
|
-
inline CUDA_CALLABLE int tid()
|
|
1108
|
+
inline CUDA_CALLABLE int tid(size_t index)
|
|
1151
1109
|
{
|
|
1152
|
-
const size_t index = grid_index();
|
|
1153
|
-
|
|
1154
1110
|
// For the 1-D tid() we need to warn the user if we're about to provide a truncated index
|
|
1155
1111
|
// Only do this in _DEBUG when called from device to avoid excessive register allocation
|
|
1156
1112
|
#if defined(_DEBUG) || !defined(__CUDA_ARCH__)
|
|
@@ -1161,23 +1117,19 @@ inline CUDA_CALLABLE int tid()
|
|
|
1161
1117
|
return static_cast<int>(index);
|
|
1162
1118
|
}
|
|
1163
1119
|
|
|
1164
|
-
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j)
|
|
1120
|
+
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
|
|
1165
1121
|
{
|
|
1166
|
-
const size_t
|
|
1167
|
-
|
|
1168
|
-
const int n = s_launchBounds.shape[1];
|
|
1122
|
+
const size_t n = launch_bounds.shape[1];
|
|
1169
1123
|
|
|
1170
1124
|
// convert to work item
|
|
1171
1125
|
i = index/n;
|
|
1172
1126
|
j = index%n;
|
|
1173
1127
|
}
|
|
1174
1128
|
|
|
1175
|
-
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
|
|
1129
|
+
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
|
|
1176
1130
|
{
|
|
1177
|
-
const size_t
|
|
1178
|
-
|
|
1179
|
-
const int n = s_launchBounds.shape[1];
|
|
1180
|
-
const int o = s_launchBounds.shape[2];
|
|
1131
|
+
const size_t n = launch_bounds.shape[1];
|
|
1132
|
+
const size_t o = launch_bounds.shape[2];
|
|
1181
1133
|
|
|
1182
1134
|
// convert to work item
|
|
1183
1135
|
i = index/(n*o);
|
|
@@ -1185,13 +1137,11 @@ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
|
|
|
1185
1137
|
k = index%o;
|
|
1186
1138
|
}
|
|
1187
1139
|
|
|
1188
|
-
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l)
|
|
1140
|
+
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
|
|
1189
1141
|
{
|
|
1190
|
-
const size_t
|
|
1191
|
-
|
|
1192
|
-
const
|
|
1193
|
-
const int o = s_launchBounds.shape[2];
|
|
1194
|
-
const int p = s_launchBounds.shape[3];
|
|
1142
|
+
const size_t n = launch_bounds.shape[1];
|
|
1143
|
+
const size_t o = launch_bounds.shape[2];
|
|
1144
|
+
const size_t p = launch_bounds.shape[3];
|
|
1195
1145
|
|
|
1196
1146
|
// convert to work item
|
|
1197
1147
|
i = index/(n*o*p);
|
|
@@ -1203,11 +1153,11 @@ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l)
|
|
|
1203
1153
|
template<typename T>
|
|
1204
1154
|
inline CUDA_CALLABLE T atomic_add(T* buf, T value)
|
|
1205
1155
|
{
|
|
1206
|
-
#if defined(
|
|
1156
|
+
#if !defined(__CUDA_ARCH__)
|
|
1207
1157
|
T old = buf[0];
|
|
1208
1158
|
buf[0] += value;
|
|
1209
1159
|
return old;
|
|
1210
|
-
#
|
|
1160
|
+
#else
|
|
1211
1161
|
return atomicAdd(buf, value);
|
|
1212
1162
|
#endif
|
|
1213
1163
|
}
|
|
@@ -1215,11 +1165,14 @@ inline CUDA_CALLABLE T atomic_add(T* buf, T value)
|
|
|
1215
1165
|
template<>
|
|
1216
1166
|
inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
|
|
1217
1167
|
{
|
|
1218
|
-
#if defined(
|
|
1168
|
+
#if !defined(__CUDA_ARCH__)
|
|
1219
1169
|
float16 old = buf[0];
|
|
1220
1170
|
buf[0] += value;
|
|
1221
1171
|
return old;
|
|
1222
|
-
#elif defined(
|
|
1172
|
+
#elif defined(__clang__) // CUDA compiled by Clang
|
|
1173
|
+
__half r = atomicAdd(reinterpret_cast<__half*>(buf), *reinterpret_cast<__half*>(&value));
|
|
1174
|
+
return *reinterpret_cast<float16*>(&r);
|
|
1175
|
+
#else // CUDA compiled by NVRTC
|
|
1223
1176
|
//return atomicAdd(buf, value);
|
|
1224
1177
|
|
|
1225
1178
|
/* Define __PTR for atomicAdd prototypes below, undef after done */
|
|
@@ -1243,7 +1196,7 @@ inline CUDA_CALLABLE float16 atomic_add(float16* buf, float16 value)
|
|
|
1243
1196
|
|
|
1244
1197
|
#undef __PTR
|
|
1245
1198
|
|
|
1246
|
-
#endif
|
|
1199
|
+
#endif // CUDA compiled by NVRTC
|
|
1247
1200
|
|
|
1248
1201
|
}
|
|
1249
1202
|
|
|
@@ -1318,9 +1271,36 @@ inline CUDA_CALLABLE int atomic_min(int* address, int val)
|
|
|
1318
1271
|
#endif
|
|
1319
1272
|
}
|
|
1320
1273
|
|
|
1274
|
+
// default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
|
|
1275
|
+
template <typename T>
|
|
1276
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
|
|
1277
|
+
{
|
|
1278
|
+
if (value == *addr)
|
|
1279
|
+
adj_value += *adj_addr;
|
|
1280
|
+
}
|
|
1281
|
+
|
|
1282
|
+
// for integral types we do not accumulate gradients
|
|
1283
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
|
|
1284
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
|
|
1285
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
|
|
1286
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
|
|
1287
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
|
|
1288
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
|
|
1289
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
|
|
1290
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
|
|
1291
|
+
CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
|
|
1292
|
+
|
|
1321
1293
|
|
|
1322
1294
|
} // namespace wp
|
|
1323
1295
|
|
|
1296
|
+
|
|
1297
|
+
// bool and printf are defined outside of the wp namespace in crt.h, hence
|
|
1298
|
+
// their adjoint counterparts are also defined in the global namespace.
|
|
1299
|
+
template <typename T>
|
|
1300
|
+
CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
|
|
1301
|
+
inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
|
|
1302
|
+
|
|
1303
|
+
|
|
1324
1304
|
#include "vec.h"
|
|
1325
1305
|
#include "mat.h"
|
|
1326
1306
|
#include "quat.h"
|
|
@@ -1485,10 +1465,6 @@ inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_
|
|
|
1485
1465
|
inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
|
|
1486
1466
|
|
|
1487
1467
|
|
|
1488
|
-
// printf defined globally in crt.h
|
|
1489
|
-
inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
1468
|
template <typename T>
|
|
1493
1469
|
inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
|
|
1494
1470
|
{
|
|
@@ -1528,7 +1504,7 @@ inline CUDA_CALLABLE void expect_near(const T& actual, const T& expected, const
|
|
|
1528
1504
|
{
|
|
1529
1505
|
if (abs(actual - expected) > tolerance)
|
|
1530
1506
|
{
|
|
1531
|
-
printf("Error, expect_near() failed with
|
|
1507
|
+
printf("Error, expect_near() failed with tolerance "); print(tolerance);
|
|
1532
1508
|
printf("\t Expected: "); print(expected);
|
|
1533
1509
|
printf("\t Actual: "); print(actual);
|
|
1534
1510
|
}
|
|
@@ -1539,7 +1515,7 @@ inline CUDA_CALLABLE void expect_near(const vec3& actual, const vec3& expected,
|
|
|
1539
1515
|
const float diff = max(max(abs(actual[0] - expected[0]), abs(actual[1] - expected[1])), abs(actual[2] - expected[2]));
|
|
1540
1516
|
if (diff > tolerance)
|
|
1541
1517
|
{
|
|
1542
|
-
printf("Error, expect_near() failed with
|
|
1518
|
+
printf("Error, expect_near() failed with tolerance "); print(tolerance);
|
|
1543
1519
|
printf("\t Expected: "); print(expected);
|
|
1544
1520
|
printf("\t Actual: "); print(actual);
|
|
1545
1521
|
}
|