warp-lang 1.0.2-py3-none-manylinux2014_aarch64.whl → 1.1.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +234 -219
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -126
- warp/examples/core/example_marching_cubes.py +188 -174
- warp/examples/core/example_mesh.py +174 -155
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -170
- warp/examples/core/example_raycast.py +105 -90
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -387
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -389
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -246
- warp/examples/optim/example_cloth_throw.py +222 -209
- warp/examples/optim/example_diffray.py +566 -536
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
- warp/examples/optim/example_spring_cage.py +239 -231
- warp/examples/optim/example_trajectory.py +223 -199
- warp/examples/optim/example_walker.py +306 -293
- warp/examples/sim/example_cartpole.py +139 -129
- warp/examples/sim/example_cloth.py +196 -186
- warp/examples/sim/example_granular.py +124 -111
- warp/examples/sim/example_granular_collision_sdf.py +197 -186
- warp/examples/sim/example_jacobian_ik.py +236 -214
- warp/examples/sim/example_particle_chain.py +118 -105
- warp/examples/sim/example_quadruped.py +193 -180
- warp/examples/sim/example_rigid_chain.py +197 -187
- warp/examples/sim/example_rigid_contact.py +189 -177
- warp/examples/sim/example_rigid_force.py +127 -125
- warp/examples/sim/example_rigid_gyroscopic.py +109 -95
- warp/examples/sim/example_rigid_soft_contact.py +134 -122
- warp/examples/sim/example_soft_body.py +190 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/native/reduce.cu
CHANGED
@@ -1,348 +1,348 @@
(The hunk removes and re-adds all 348 lines; the old and new text are character-identical, so the visible content is unchanged and is shown once below.)

#include "cuda_util.h"
#include "warp.h"

#include "temp_buffer.h"

#define THRUST_IGNORE_CUB_VERSION_CHECK
#include <cub/device/device_reduce.cuh>
#include <cub/iterator/counting_input_iterator.cuh>

namespace
{

template <typename T>
__global__ void cwise_mult_kernel(int len, int stride_a, int stride_b, const T *a, const T *b, T *out)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= len)
        return;
    out[i] = a[i * stride_a] * b[i * stride_b];
}

/// Custom iterator for allowing strided access with CUB
template <typename T> struct cub_strided_iterator
{
    typedef cub_strided_iterator<T> self_type;
    typedef std::ptrdiff_t difference_type;
    typedef T value_type;
    typedef T *pointer;
    typedef T &reference;

    typedef std::random_access_iterator_tag iterator_category; ///< The iterator category

    T *ptr = nullptr;
    int stride = 1;

    CUDA_CALLABLE self_type operator++(int)
    {
        return ++(self_type(*this));
    }

    CUDA_CALLABLE self_type &operator++()
    {
        ptr += stride;
        return *this;
    }

    __host__ __device__ __forceinline__ reference operator*() const
    {
        return *ptr;
    }

    CUDA_CALLABLE self_type operator+(difference_type n) const
    {
        return self_type(*this) += n;
    }

    CUDA_CALLABLE self_type &operator+=(difference_type n)
    {
        ptr += n * stride;
        return *this;
    }

    CUDA_CALLABLE self_type operator-(difference_type n) const
    {
        return self_type(*this) -= n;
    }

    CUDA_CALLABLE self_type &operator-=(difference_type n)
    {
        ptr -= n * stride;
        return *this;
    }

    CUDA_CALLABLE difference_type operator-(const self_type &other) const
    {
        return (ptr - other.ptr) / stride;
    }

    CUDA_CALLABLE reference operator[](difference_type n) const
    {
        return *(ptr + n * stride);
    }

    CUDA_CALLABLE pointer operator->() const
    {
        return ptr;
    }

    CUDA_CALLABLE bool operator==(const self_type &rhs) const
    {
        return (ptr == rhs.ptr);
    }

    CUDA_CALLABLE bool operator!=(const self_type &rhs) const
    {
        return (ptr != rhs.ptr);
    }
};

template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
{
    assert((byte_stride % sizeof(T)) == 0);
    const int stride = byte_stride / sizeof(T);

    ContextGuard guard(cuda_context_get_current());
    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

    cub_strided_iterator<const T> ptr_strided{ptr_a, stride};

    size_t buff_size = 0;
    check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
    void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);

    for (int k = 0; k < type_length; ++k)
    {
        cub_strided_iterator<const T> ptr_strided{ptr_a + k, stride};
        check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
    }

    free_device(WP_CURRENT_CONTEXT, temp_buffer);
}

template <typename T>
void array_sum_device_dispatch(const T *ptr_a, T *ptr_out, int count, int byte_stride, int type_length)
{
    using vec2 = wp::vec_t<2, T>;
    using vec3 = wp::vec_t<3, T>;
    using vec4 = wp::vec_t<4, T>;

    // specialized calls for common vector types

    if ((type_length % 4) == 0 && (byte_stride % sizeof(vec4)) == 0)
    {
        return array_sum_device(reinterpret_cast<const vec4 *>(ptr_a), reinterpret_cast<vec4 *>(ptr_out), count,
                                byte_stride, type_length / 4);
    }

    if ((type_length % 3) == 0 && (byte_stride % sizeof(vec3)) == 0)
    {
        return array_sum_device(reinterpret_cast<const vec3 *>(ptr_a), reinterpret_cast<vec3 *>(ptr_out), count,
                                byte_stride, type_length / 3);
    }

    if ((type_length % 2) == 0 && (byte_stride % sizeof(vec2)) == 0)
    {
        return array_sum_device(reinterpret_cast<const vec2 *>(ptr_a), reinterpret_cast<vec2 *>(ptr_out), count,
                                byte_stride, type_length / 2);
    }

    return array_sum_device(ptr_a, ptr_out, count, byte_stride, type_length);
}

template <typename T> CUDA_CALLABLE T element_inner_product(const T &a, const T &b)
{
    return a * b;
}

template <unsigned Length, typename T>
CUDA_CALLABLE T element_inner_product(const wp::vec_t<Length, T> &a, const wp::vec_t<Length, T> &b)
{
    return wp::dot(a, b);
}

/// Custom iterator for allowing strided access with CUB
template <typename ElemT, typename ScalarT> struct cub_inner_product_iterator
{
    typedef cub_inner_product_iterator<ElemT, ScalarT> self_type;
    typedef std::ptrdiff_t difference_type;
    typedef ScalarT value_type;
    typedef ScalarT *pointer;
    typedef ScalarT reference;

    typedef std::random_access_iterator_tag iterator_category; ///< The iterator category

    const ElemT *ptr_a = nullptr;
    const ElemT *ptr_b = nullptr;

    int stride_a = 1;
    int stride_b = 1;
    int type_length = 1;

    CUDA_CALLABLE self_type operator++(int)
    {
        return ++(self_type(*this));
    }

    CUDA_CALLABLE self_type &operator++()
    {
        ptr_a += stride_a;
        ptr_b += stride_b;
        return *this;
    }

    __host__ __device__ __forceinline__ reference operator*() const
    {
        return compute_value(0);
    }

    CUDA_CALLABLE self_type operator+(difference_type n) const
    {
        return self_type(*this) += n;
    }

    CUDA_CALLABLE self_type &operator+=(difference_type n)
    {
        ptr_a += n * stride_a;
        ptr_b += n * stride_b;
        return *this;
    }

    CUDA_CALLABLE self_type operator-(difference_type n) const
    {
        return self_type(*this) -= n;
    }

    CUDA_CALLABLE self_type &operator-=(difference_type n)
    {
        ptr_a -= n * stride_a;
        ptr_b -= n * stride_b;
        return *this;
    }

    CUDA_CALLABLE difference_type operator-(const self_type &other) const
    {
        return (ptr_a - other.ptr_a) / stride_a;
    }

    CUDA_CALLABLE reference operator[](difference_type n) const
    {
        return compute_value(n);
    }

    CUDA_CALLABLE bool operator==(const self_type &rhs) const
    {
        return (ptr_a == rhs.ptr_a);
    }

    CUDA_CALLABLE bool operator!=(const self_type &rhs) const
    {
        return (ptr_a != rhs.ptr_a);
    }

private:
    CUDA_CALLABLE ScalarT compute_value(difference_type n) const
    {
        ScalarT val(0);
        const ElemT *a = ptr_a + n * stride_a;
        const ElemT *b = ptr_b + n * stride_b;
        for (int k = 0; k < type_length; ++k)
        {
            val += element_inner_product(a[k], b[k]);
        }
        return val;
    }
};

template <typename ElemT, typename ScalarT>
void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out, int count, int byte_stride_a,
                        int byte_stride_b, int type_length)
{
    assert((byte_stride_a % sizeof(ElemT)) == 0);
    assert((byte_stride_b % sizeof(ElemT)) == 0);
    const int stride_a = byte_stride_a / sizeof(ElemT);
    const int stride_b = byte_stride_b / sizeof(ElemT);

    ContextGuard guard(cuda_context_get_current());
    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

    cub_inner_product_iterator<ElemT, ScalarT> inner_iterator{ptr_a, ptr_b, stride_a, stride_b, type_length};

    size_t buff_size = 0;
    check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
    void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);

    check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));

    free_device(WP_CURRENT_CONTEXT, temp_buffer);
}

template <typename T>
void array_inner_device_dispatch(const T *ptr_a, const T *ptr_b, T *ptr_out, int count, int byte_stride_a,
                                 int byte_stride_b, int type_length)
{
    using vec2 = wp::vec_t<2, T>;
    using vec3 = wp::vec_t<3, T>;
    using vec4 = wp::vec_t<4, T>;

    // specialized calls for common vector types

    if ((type_length % 4) == 0 && (byte_stride_a % sizeof(vec4)) == 0 && (byte_stride_b % sizeof(vec4)) == 0)
    {
        return array_inner_device(reinterpret_cast<const vec4 *>(ptr_a), reinterpret_cast<const vec4 *>(ptr_b), ptr_out,
                                  count, byte_stride_a, byte_stride_b, type_length / 4);
    }

    if ((type_length % 3) == 0 && (byte_stride_a % sizeof(vec3)) == 0 && (byte_stride_b % sizeof(vec3)) == 0)
    {
        return array_inner_device(reinterpret_cast<const vec3 *>(ptr_a), reinterpret_cast<const vec3 *>(ptr_b), ptr_out,
                                  count, byte_stride_a, byte_stride_b, type_length / 3);
    }

    if ((type_length % 2) == 0 && (byte_stride_a % sizeof(vec2)) == 0 && (byte_stride_b % sizeof(vec2)) == 0)
    {
        return array_inner_device(reinterpret_cast<const vec2 *>(ptr_a), reinterpret_cast<const vec2 *>(ptr_b), ptr_out,
                                  count, byte_stride_a, byte_stride_b, type_length / 2);
    }

    return array_inner_device(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
}

} // anonymous namespace

void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                              int type_len)
{
    void *context = cuda_context_get_current();

    const float *ptr_a = (const float *)(a);
    const float *ptr_b = (const float *)(b);
    float *ptr_out = (float *)(out);

    array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
}

void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                               int type_len)
{
    const double *ptr_a = (const double *)(a);
    const double *ptr_b = (const double *)(b);
    double *ptr_out = (double *)(out);

    array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
}

void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
{
    const float *ptr_a = (const float *)(a);
    float *ptr_out = (float *)(out);
    array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
}

void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
{
    const double *ptr_a = (const double *)(a);
    double *ptr_out = (double *)(out);
    array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
}
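The reductions in this file all follow CUB's two-phase temp-storage idiom: cub::DeviceReduce::Sum is first called with a null buffer pointer, which only writes the required scratch size into buff_size, and the second call performs the actual reduction. Below is a minimal standalone sketch of that pattern, not part of the package: it uses plain CUDA plus CUB rather than Warp's alloc_device/check_cuda/ContextGuard helpers, and the names d_in, d_out, and d_temp are illustrative only.

// Sketch of the CUB size-query-then-run idiom used by array_sum_device
// and array_inner_device above. Assumes the CUB that ships with the CUDA
// toolkit; error handling is omitted for brevity.
#include <cub/device/device_reduce.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main()
{
    const int n = 1024;
    float ones[n];
    for (int i = 0; i < n; ++i)
        ones[i] = 1.0f;

    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, ones, n * sizeof(float), cudaMemcpyHostToDevice);

    // Phase 1: with d_temp == nullptr, CUB only reports the scratch size.
    void *d_temp = nullptr;
    size_t temp_bytes = 0;
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

    // Phase 2: allocate the scratch buffer and run the reduction for real.
    cudaMalloc(&d_temp, temp_bytes);
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

    float result = 0.0f;
    cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", result); // expected: 1024.0

    cudaFree(d_temp);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Warp's versions differ only in routing allocation and error checks through its own context guard, current-stream lookup, and device allocator, and in substituting the strided and inner-product iterators for the raw input pointer so that interleaved vector components can be reduced without a preliminary gather pass.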