cuda-cccl 0.1.3.2.0.dev438__cp311-cp311-manylinux_2_24_aarch64.whl → 0.3.1__cp311-cp311-manylinux_2_24_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cuda-cccl might be problematic. Click here for more details.
- cuda/cccl/cooperative/__init__.py +7 -1
- cuda/cccl/cooperative/experimental/__init__.py +21 -5
- cuda/cccl/headers/include/cub/agent/agent_adjacent_difference.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_batch_memcpy.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_for.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_merge.cuh +23 -21
- cuda/cccl/headers/include/cub/agent/agent_merge_sort.cuh +21 -3
- cuda/cccl/headers/include/cub/agent/agent_radix_sort_downsweep.cuh +25 -5
- cuda/cccl/headers/include/cub/agent/agent_radix_sort_histogram.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_radix_sort_onesweep.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_radix_sort_upsweep.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_rle.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_scan.cuh +5 -1
- cuda/cccl/headers/include/cub/agent/agent_scan_by_key.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_segmented_radix_sort.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_select_if.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_sub_warp_merge_sort.cuh +24 -19
- cuda/cccl/headers/include/cub/agent/agent_three_way_partition.cuh +2 -5
- cuda/cccl/headers/include/cub/agent/agent_unique_by_key.cuh +22 -5
- cuda/cccl/headers/include/cub/block/block_load_to_shared.cuh +432 -0
- cuda/cccl/headers/include/cub/block/block_radix_rank.cuh +3 -2
- cuda/cccl/headers/include/cub/block/block_radix_sort.cuh +4 -2
- cuda/cccl/headers/include/cub/detail/device_memory_resource.cuh +1 -0
- cuda/cccl/headers/include/cub/detail/mdspan_utils.cuh +18 -26
- cuda/cccl/headers/include/cub/device/device_copy.cuh +116 -27
- cuda/cccl/headers/include/cub/device/device_partition.cuh +5 -1
- cuda/cccl/headers/include/cub/device/device_segmented_reduce.cuh +158 -247
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_copy_mdspan.cuh +79 -0
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_merge.cuh +4 -4
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_radix_sort.cuh +2 -11
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce.cuh +8 -26
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_deterministic.cuh +1 -6
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_reduce_nondeterministic.cuh +0 -1
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_segmented_sort.cuh +320 -262
- cuda/cccl/headers/include/cub/device/dispatch/kernels/reduce.cuh +10 -5
- cuda/cccl/headers/include/cub/device/dispatch/kernels/scan.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_reduce.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/kernels/segmented_sort.cuh +57 -10
- cuda/cccl/headers/include/cub/device/dispatch/kernels/transform.cuh +37 -13
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_adjacent_difference.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_batch_memcpy.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_for.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_histogram.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_merge_sort.cuh +8 -0
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_radix_sort.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_run_length_encode.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_scan_by_key.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_segmented_sort.cuh +204 -55
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_three_way_partition.cuh +2 -5
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_transform.cuh +55 -19
- cuda/cccl/headers/include/cub/device/dispatch/tuning/tuning_unique_by_key.cuh +10 -0
- cuda/cccl/headers/include/cub/util_device.cuh +51 -35
- cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_shfl.cuh +3 -2
- cuda/cccl/headers/include/cub/warp/specializations/warp_reduce_smem.cuh +3 -2
- cuda/cccl/headers/include/cub/warp/specializations/warp_scan_shfl.cuh +2 -2
- cuda/cccl/headers/include/cuda/__algorithm/common.h +1 -1
- cuda/cccl/headers/include/cuda/__algorithm/copy.h +4 -4
- cuda/cccl/headers/include/cuda/__algorithm/fill.h +1 -1
- cuda/cccl/headers/include/cuda/__device/all_devices.h +47 -147
- cuda/cccl/headers/include/cuda/__device/arch_traits.h +51 -49
- cuda/cccl/headers/include/cuda/__device/attributes.h +177 -127
- cuda/cccl/headers/include/cuda/__device/device_ref.h +32 -51
- cuda/cccl/headers/include/cuda/__device/physical_device.h +120 -91
- cuda/cccl/headers/include/cuda/__driver/driver_api.h +330 -36
- cuda/cccl/headers/include/cuda/__event/event.h +8 -8
- cuda/cccl/headers/include/cuda/__event/event_ref.h +4 -5
- cuda/cccl/headers/include/cuda/__event/timed_event.h +4 -4
- cuda/cccl/headers/include/cuda/__fwd/devices.h +44 -0
- cuda/cccl/headers/include/cuda/__fwd/zip_iterator.h +9 -0
- cuda/cccl/headers/include/cuda/__iterator/transform_input_output_iterator.h +3 -3
- cuda/cccl/headers/include/cuda/__iterator/transform_iterator.h +3 -3
- cuda/cccl/headers/include/cuda/__iterator/transform_output_iterator.h +3 -3
- cuda/cccl/headers/include/cuda/__iterator/zip_common.h +158 -0
- cuda/cccl/headers/include/cuda/__iterator/zip_iterator.h +8 -120
- cuda/cccl/headers/include/cuda/__iterator/zip_transform_iterator.h +593 -0
- cuda/cccl/headers/include/cuda/__mdspan/host_device_accessor.h +14 -10
- cuda/cccl/headers/include/cuda/__runtime/ensure_current_context.h +4 -3
- cuda/cccl/headers/include/cuda/__runtime/types.h +1 -1
- cuda/cccl/headers/include/cuda/__stream/stream.h +2 -3
- cuda/cccl/headers/include/cuda/__stream/stream_ref.h +18 -12
- cuda/cccl/headers/include/cuda/__utility/__basic_any/virtual_tables.h +2 -2
- cuda/cccl/headers/include/cuda/__utility/basic_any.h +1 -1
- cuda/cccl/headers/include/cuda/algorithm +1 -1
- cuda/cccl/headers/include/cuda/devices +10 -0
- cuda/cccl/headers/include/cuda/iterator +1 -0
- cuda/cccl/headers/include/cuda/std/__bit/countl.h +8 -1
- cuda/cccl/headers/include/cuda/std/__bit/countr.h +2 -2
- cuda/cccl/headers/include/cuda/std/__bit/reference.h +11 -11
- cuda/cccl/headers/include/cuda/std/__cccl/cuda_capabilities.h +2 -2
- cuda/cccl/headers/include/cuda/std/__cccl/preprocessor.h +2 -0
- cuda/cccl/headers/include/cuda/std/__chrono/duration.h +16 -16
- cuda/cccl/headers/include/cuda/std/__chrono/steady_clock.h +5 -5
- cuda/cccl/headers/include/cuda/std/__chrono/system_clock.h +5 -5
- cuda/cccl/headers/include/cuda/std/__cmath/isnan.h +10 -5
- cuda/cccl/headers/include/cuda/std/__cmath/min_max.h +44 -17
- cuda/cccl/headers/include/cuda/std/__concepts/constructible.h +1 -1
- cuda/cccl/headers/include/cuda/std/__cuda/api_wrapper.h +12 -12
- cuda/cccl/headers/include/cuda/std/__exception/cuda_error.h +1 -8
- cuda/cccl/headers/include/cuda/std/__floating_point/cast.h +15 -12
- cuda/cccl/headers/include/cuda/std/__floating_point/cuda_fp_types.h +3 -0
- cuda/cccl/headers/include/cuda/std/__floating_point/fp.h +1 -1
- cuda/cccl/headers/include/cuda/std/__mdspan/mdspan.h +2 -1
- cuda/cccl/headers/include/cuda/std/__tuple_dir/make_tuple_types.h +23 -1
- cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like.h +4 -0
- cuda/cccl/headers/include/cuda/std/__tuple_dir/tuple_like_ext.h +4 -0
- cuda/cccl/headers/include/cuda/std/__type_traits/promote.h +3 -2
- cuda/cccl/headers/include/cuda/std/string_view +12 -5
- cuda/cccl/headers/include/cuda/std/version +1 -4
- cuda/cccl/headers/include/thrust/detail/integer_math.h +3 -20
- cuda/cccl/headers/include/thrust/iterator/iterator_traits.h +11 -0
- cuda/cccl/headers/include/thrust/system/cuda/detail/copy.h +33 -0
- cuda/cccl/headers/include/thrust/system/cuda/detail/tabulate.h +8 -22
- cuda/cccl/headers/include/thrust/type_traits/unwrap_contiguous_iterator.h +15 -48
- cuda/cccl/parallel/experimental/__init__.py +21 -70
- cuda/compute/__init__.py +77 -0
- cuda/{cccl/parallel/experimental → compute}/_bindings.pyi +28 -0
- cuda/{cccl/parallel/experimental → compute}/_bindings_impl.pyx +141 -1
- cuda/{cccl/parallel/experimental → compute}/algorithms/__init__.py +4 -0
- cuda/{cccl/parallel/experimental → compute}/algorithms/_histogram.py +2 -2
- cuda/{cccl/parallel/experimental → compute}/algorithms/_merge_sort.py +2 -2
- cuda/{cccl/parallel/experimental → compute}/algorithms/_radix_sort.py +3 -3
- cuda/{cccl/parallel/experimental → compute}/algorithms/_reduce.py +2 -4
- cuda/{cccl/parallel/experimental → compute}/algorithms/_scan.py +4 -6
- cuda/{cccl/parallel/experimental → compute}/algorithms/_segmented_reduce.py +2 -2
- cuda/compute/algorithms/_three_way_partition.py +261 -0
- cuda/{cccl/parallel/experimental → compute}/algorithms/_transform.py +4 -4
- cuda/{cccl/parallel/experimental → compute}/algorithms/_unique_by_key.py +2 -2
- cuda/compute/cu12/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
- cuda/{cccl/parallel/experimental → compute}/cu12/cccl/libcccl.c.parallel.so +0 -0
- cuda/compute/cu13/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
- cuda/{cccl/parallel/experimental → compute}/cu13/cccl/libcccl.c.parallel.so +0 -0
- cuda/{cccl/parallel/experimental → compute}/iterators/_factories.py +8 -8
- cuda/{cccl/parallel/experimental → compute}/struct.py +2 -2
- cuda/coop/__init__.py +8 -0
- cuda/{cccl/cooperative/experimental → coop}/_nvrtc.py +3 -2
- cuda/{cccl/cooperative/experimental → coop}/_scan_op.py +3 -3
- cuda/{cccl/cooperative/experimental → coop}/_types.py +2 -2
- cuda/{cccl/cooperative/experimental → coop}/_typing.py +1 -1
- cuda/{cccl/cooperative/experimental → coop}/block/__init__.py +6 -6
- cuda/{cccl/cooperative/experimental → coop}/block/_block_exchange.py +4 -4
- cuda/{cccl/cooperative/experimental → coop}/block/_block_load_store.py +6 -6
- cuda/{cccl/cooperative/experimental → coop}/block/_block_merge_sort.py +4 -4
- cuda/{cccl/cooperative/experimental → coop}/block/_block_radix_sort.py +6 -6
- cuda/{cccl/cooperative/experimental → coop}/block/_block_reduce.py +6 -6
- cuda/{cccl/cooperative/experimental → coop}/block/_block_scan.py +7 -7
- cuda/coop/warp/__init__.py +9 -0
- cuda/{cccl/cooperative/experimental → coop}/warp/_warp_merge_sort.py +3 -3
- cuda/{cccl/cooperative/experimental → coop}/warp/_warp_reduce.py +6 -6
- cuda/{cccl/cooperative/experimental → coop}/warp/_warp_scan.py +4 -4
- {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/METADATA +1 -1
- {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/RECORD +171 -166
- cuda/cccl/cooperative/experimental/warp/__init__.py +0 -9
- cuda/cccl/headers/include/cub/device/dispatch/dispatch_advance_iterators.cuh +0 -111
- cuda/cccl/headers/include/cuda/std/__cuda/ensure_current_device.h +0 -72
- cuda/cccl/parallel/experimental/.gitignore +0 -4
- cuda/cccl/parallel/experimental/cu12/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
- cuda/cccl/parallel/experimental/cu13/_bindings_impl.cpython-311-aarch64-linux-gnu.so +0 -0
- /cuda/{cccl/parallel/experimental → compute}/_bindings.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/_caching.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/_cccl_interop.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/_utils/__init__.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/_utils/protocols.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/_utils/temp_storage_buffer.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/cccl/.gitkeep +0 -0
- /cuda/{cccl/parallel/experimental → compute}/iterators/__init__.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/iterators/_iterators.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/iterators/_zip_iterator.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/numba_utils.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/op.py +0 -0
- /cuda/{cccl/parallel/experimental → compute}/typing.py +0 -0
- /cuda/{cccl/cooperative/experimental → coop}/_caching.py +0 -0
- /cuda/{cccl/cooperative/experimental → coop}/_common.py +0 -0
- {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/WHEEL +0 -0
- {cuda_cccl-0.1.3.2.0.dev438.dist-info → cuda_cccl-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -43,10 +43,120 @@
|
|
|
43
43
|
|
|
44
44
|
CUB_NAMESPACE_BEGIN
|
|
45
45
|
|
|
46
|
-
namespace detail
|
|
46
|
+
namespace detail::segmented_sort
|
|
47
47
|
{
|
|
48
|
-
|
|
48
|
+
|
|
49
|
+
template <typename PolicyT, typename = void>
|
|
50
|
+
struct SegmentedSortPolicyWrapper : PolicyT
|
|
51
|
+
{
|
|
52
|
+
CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper(PolicyT base)
|
|
53
|
+
: PolicyT(base)
|
|
54
|
+
{}
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
template <typename StaticPolicyT>
|
|
58
|
+
struct SegmentedSortPolicyWrapper<StaticPolicyT,
|
|
59
|
+
_CUDA_VSTD::void_t<typename StaticPolicyT::LargeSegmentPolicy,
|
|
60
|
+
typename StaticPolicyT::SmallSegmentPolicy,
|
|
61
|
+
typename StaticPolicyT::MediumSegmentPolicy>> : StaticPolicyT
|
|
49
62
|
{
|
|
63
|
+
CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper(StaticPolicyT base)
|
|
64
|
+
: StaticPolicyT(base)
|
|
65
|
+
{}
|
|
66
|
+
|
|
67
|
+
CUB_RUNTIME_FUNCTION static constexpr auto LargeSegment()
|
|
68
|
+
{
|
|
69
|
+
return cub::detail::MakePolicyWrapper(typename StaticPolicyT::LargeSegmentPolicy());
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
CUB_RUNTIME_FUNCTION static constexpr auto SmallSegment()
|
|
73
|
+
{
|
|
74
|
+
return cub::detail::MakePolicyWrapper(typename StaticPolicyT::SmallSegmentPolicy());
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
CUB_RUNTIME_FUNCTION static constexpr auto MediumSegment()
|
|
78
|
+
{
|
|
79
|
+
return cub::detail::MakePolicyWrapper(typename StaticPolicyT::MediumSegmentPolicy());
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
CUB_RUNTIME_FUNCTION static constexpr int PartitioningThreshold()
|
|
83
|
+
{
|
|
84
|
+
return StaticPolicyT::PARTITIONING_THRESHOLD;
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
CUB_RUNTIME_FUNCTION static constexpr int LargeSegmentRadixBits()
|
|
88
|
+
{
|
|
89
|
+
return StaticPolicyT::LargeSegmentPolicy::RADIX_BITS;
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
CUB_RUNTIME_FUNCTION static constexpr int SegmentsPerSmallBlock()
|
|
93
|
+
{
|
|
94
|
+
return StaticPolicyT::SmallSegmentPolicy::SEGMENTS_PER_BLOCK;
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
CUB_RUNTIME_FUNCTION static constexpr int SegmentsPerMediumBlock()
|
|
98
|
+
{
|
|
99
|
+
return StaticPolicyT::MediumSegmentPolicy::SEGMENTS_PER_BLOCK;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
CUB_RUNTIME_FUNCTION static constexpr int SmallPolicyItemsPerTile()
|
|
103
|
+
{
|
|
104
|
+
return StaticPolicyT::SmallSegmentPolicy::ITEMS_PER_TILE;
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
CUB_RUNTIME_FUNCTION static constexpr int MediumPolicyItemsPerTile()
|
|
108
|
+
{
|
|
109
|
+
return StaticPolicyT::MediumSegmentPolicy::ITEMS_PER_TILE;
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
CUB_RUNTIME_FUNCTION static constexpr CacheLoadModifier LargeSegmentLoadModifier()
|
|
113
|
+
{
|
|
114
|
+
return StaticPolicyT::LargeSegmentPolicy::LOAD_MODIFIER;
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
CUB_RUNTIME_FUNCTION static constexpr BlockLoadAlgorithm LargeSegmentLoadAlgorithm()
|
|
118
|
+
{
|
|
119
|
+
return StaticPolicyT::LargeSegmentPolicy::LOAD_ALGORITHM;
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
CUB_RUNTIME_FUNCTION static constexpr WarpLoadAlgorithm MediumSegmentLoadAlgorithm()
|
|
123
|
+
{
|
|
124
|
+
return StaticPolicyT::MediumSegmentPolicy::LOAD_ALGORITHM;
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
CUB_RUNTIME_FUNCTION static constexpr WarpLoadAlgorithm SmallSegmentLoadAlgorithm()
|
|
128
|
+
{
|
|
129
|
+
return StaticPolicyT::SmallSegmentPolicy::LOAD_ALGORITHM;
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
CUB_RUNTIME_FUNCTION static constexpr WarpStoreAlgorithm MediumSegmentStoreAlgorithm()
|
|
133
|
+
{
|
|
134
|
+
return StaticPolicyT::MediumSegmentPolicy::STORE_ALGORITHM;
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
CUB_RUNTIME_FUNCTION static constexpr WarpStoreAlgorithm SmallSegmentStoreAlgorithm()
|
|
138
|
+
{
|
|
139
|
+
return StaticPolicyT::SmallSegmentPolicy::STORE_ALGORITHM;
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
#if defined(CUB_ENABLE_POLICY_PTX_JSON)
|
|
143
|
+
_CCCL_DEVICE static constexpr auto EncodedPolicy()
|
|
144
|
+
{
|
|
145
|
+
using namespace ptx_json;
|
|
146
|
+
return object<key<"LargeSegmentPolicy">() = LargeSegment().EncodedPolicy(),
|
|
147
|
+
key<"SmallSegmentPolicy">() = SmallSegment().EncodedPolicy(),
|
|
148
|
+
key<"MediumSegmentPolicy">() = MediumSegment().EncodedPolicy(),
|
|
149
|
+
key<"PartitioningThreshold">() = value<StaticPolicyT::PARTITIONING_THRESHOLD>()>();
|
|
150
|
+
}
|
|
151
|
+
#endif
|
|
152
|
+
};
|
|
153
|
+
|
|
154
|
+
template <typename PolicyT>
|
|
155
|
+
CUB_RUNTIME_FUNCTION SegmentedSortPolicyWrapper<PolicyT> MakeSegmentedSortPolicyWrapper(PolicyT policy)
|
|
156
|
+
{
|
|
157
|
+
return SegmentedSortPolicyWrapper<PolicyT>{policy};
|
|
158
|
+
}
|
|
159
|
+
|
|
50
160
|
template <typename KeyT, typename ValueT>
|
|
51
161
|
struct policy_hub
|
|
52
162
|
{
|
|
@@ -71,12 +181,19 @@ struct policy_hub
|
|
|
71
181
|
|
|
72
182
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(7);
|
|
73
183
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(7);
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
184
|
+
|
|
185
|
+
using SmallSegmentPolicy =
|
|
186
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
187
|
+
4 /* Threads per segment */,
|
|
188
|
+
ITEMS_PER_SMALL_THREAD,
|
|
189
|
+
WARP_LOAD_DIRECT,
|
|
190
|
+
LOAD_DEFAULT>;
|
|
191
|
+
using MediumSegmentPolicy =
|
|
192
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
193
|
+
32 /* Threads per segment */,
|
|
194
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
195
|
+
WARP_LOAD_DIRECT,
|
|
196
|
+
LOAD_DEFAULT>;
|
|
80
197
|
};
|
|
81
198
|
|
|
82
199
|
struct Policy600 : ChainedPolicy<600, Policy600, Policy500>
|
|
@@ -97,12 +214,19 @@ struct policy_hub
|
|
|
97
214
|
|
|
98
215
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
99
216
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
217
|
+
|
|
218
|
+
using SmallSegmentPolicy =
|
|
219
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
220
|
+
4 /* Threads per segment */,
|
|
221
|
+
ITEMS_PER_SMALL_THREAD,
|
|
222
|
+
WARP_LOAD_DIRECT,
|
|
223
|
+
LOAD_DEFAULT>;
|
|
224
|
+
using MediumSegmentPolicy =
|
|
225
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
226
|
+
32 /* Threads per segment */,
|
|
227
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
228
|
+
WARP_LOAD_DIRECT,
|
|
229
|
+
LOAD_DEFAULT>;
|
|
106
230
|
};
|
|
107
231
|
|
|
108
232
|
struct Policy610 : ChainedPolicy<610, Policy610, Policy600>
|
|
@@ -123,12 +247,19 @@ struct policy_hub
|
|
|
123
247
|
|
|
124
248
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
125
249
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
250
|
+
|
|
251
|
+
using SmallSegmentPolicy =
|
|
252
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
253
|
+
4 /* Threads per segment */,
|
|
254
|
+
ITEMS_PER_SMALL_THREAD,
|
|
255
|
+
WARP_LOAD_DIRECT,
|
|
256
|
+
LOAD_DEFAULT>;
|
|
257
|
+
using MediumSegmentPolicy =
|
|
258
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
259
|
+
32 /* Threads per segment */,
|
|
260
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
261
|
+
WARP_LOAD_DIRECT,
|
|
262
|
+
LOAD_DEFAULT>;
|
|
132
263
|
};
|
|
133
264
|
|
|
134
265
|
struct Policy620 : ChainedPolicy<620, Policy620, Policy610>
|
|
@@ -149,12 +280,19 @@ struct policy_hub
|
|
|
149
280
|
|
|
150
281
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
151
282
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
283
|
+
|
|
284
|
+
using SmallSegmentPolicy =
|
|
285
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
286
|
+
4 /* Threads per segment */,
|
|
287
|
+
ITEMS_PER_SMALL_THREAD,
|
|
288
|
+
WARP_LOAD_DIRECT,
|
|
289
|
+
LOAD_DEFAULT>;
|
|
290
|
+
using MediumSegmentPolicy =
|
|
291
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
292
|
+
32 /* Threads per segment */,
|
|
293
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
294
|
+
WARP_LOAD_DIRECT,
|
|
295
|
+
LOAD_DEFAULT>;
|
|
158
296
|
};
|
|
159
297
|
|
|
160
298
|
struct Policy700 : ChainedPolicy<700, Policy700, Policy620>
|
|
@@ -175,15 +313,19 @@ struct policy_hub
|
|
|
175
313
|
|
|
176
314
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(7);
|
|
177
315
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(KEYS_ONLY ? 11 : 7);
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
316
|
+
|
|
317
|
+
using SmallSegmentPolicy =
|
|
318
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
319
|
+
KEYS_ONLY ? 4 : 8 /* Threads per segment */,
|
|
320
|
+
ITEMS_PER_SMALL_THREAD,
|
|
321
|
+
WARP_LOAD_DIRECT,
|
|
322
|
+
LOAD_DEFAULT>;
|
|
323
|
+
using MediumSegmentPolicy =
|
|
324
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
325
|
+
32 /* Threads per segment */,
|
|
326
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
327
|
+
WARP_LOAD_DIRECT,
|
|
328
|
+
LOAD_DEFAULT>;
|
|
187
329
|
};
|
|
188
330
|
|
|
189
331
|
struct Policy800 : ChainedPolicy<800, Policy800, Policy700>
|
|
@@ -202,15 +344,19 @@ struct policy_hub
|
|
|
202
344
|
|
|
203
345
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(9);
|
|
204
346
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(KEYS_ONLY ? 7 : 11);
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
347
|
+
|
|
348
|
+
using SmallSegmentPolicy =
|
|
349
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
350
|
+
KEYS_ONLY ? 4 : 2 /* Threads per segment */,
|
|
351
|
+
ITEMS_PER_SMALL_THREAD,
|
|
352
|
+
WARP_LOAD_TRANSPOSE,
|
|
353
|
+
LOAD_DEFAULT>;
|
|
354
|
+
using MediumSegmentPolicy =
|
|
355
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
356
|
+
32 /* Threads per segment */,
|
|
357
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
358
|
+
WARP_LOAD_TRANSPOSE,
|
|
359
|
+
LOAD_DEFAULT>;
|
|
214
360
|
};
|
|
215
361
|
|
|
216
362
|
struct Policy860 : ChainedPolicy<860, Policy860, Policy800>
|
|
@@ -230,20 +376,23 @@ struct policy_hub
|
|
|
230
376
|
static constexpr bool LARGE_ITEMS = sizeof(DominantT) > 4;
|
|
231
377
|
static constexpr int ITEMS_PER_SMALL_THREAD = Nominal4BItemsToItems<DominantT>(LARGE_ITEMS ? 7 : 9);
|
|
232
378
|
static constexpr int ITEMS_PER_MEDIUM_THREAD = Nominal4BItemsToItems<DominantT>(LARGE_ITEMS ? 9 : 7);
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
379
|
+
|
|
380
|
+
using SmallSegmentPolicy =
|
|
381
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
382
|
+
LARGE_ITEMS ? 8 : 2 /* Threads per segment */,
|
|
383
|
+
ITEMS_PER_SMALL_THREAD,
|
|
384
|
+
WARP_LOAD_TRANSPOSE,
|
|
385
|
+
LOAD_LDG>;
|
|
386
|
+
using MediumSegmentPolicy =
|
|
387
|
+
AgentSubWarpMergeSortPolicy<BLOCK_THREADS,
|
|
388
|
+
16 /* Threads per segment */,
|
|
389
|
+
ITEMS_PER_MEDIUM_THREAD,
|
|
390
|
+
WARP_LOAD_TRANSPOSE,
|
|
391
|
+
LOAD_LDG>;
|
|
242
392
|
};
|
|
243
393
|
|
|
244
394
|
using MaxPolicy = Policy860;
|
|
245
395
|
};
|
|
246
|
-
} // namespace segmented_sort
|
|
247
|
-
} // namespace detail
|
|
396
|
+
} // namespace detail::segmented_sort
|
|
248
397
|
|
|
249
398
|
CUB_NAMESPACE_END
|
|
@@ -47,9 +47,7 @@
|
|
|
47
47
|
|
|
48
48
|
CUB_NAMESPACE_BEGIN
|
|
49
49
|
|
|
50
|
-
namespace detail
|
|
51
|
-
{
|
|
52
|
-
namespace three_way_partition
|
|
50
|
+
namespace detail::three_way_partition
|
|
53
51
|
{
|
|
54
52
|
|
|
55
53
|
template <typename PolicyT, typename = void>
|
|
@@ -437,7 +435,6 @@ struct policy_hub
|
|
|
437
435
|
|
|
438
436
|
using MaxPolicy = Policy1000;
|
|
439
437
|
};
|
|
440
|
-
} // namespace three_way_partition
|
|
441
|
-
} // namespace detail
|
|
438
|
+
} // namespace detail::three_way_partition
|
|
442
439
|
|
|
443
440
|
CUB_NAMESPACE_END
|
|
@@ -282,21 +282,45 @@ _CCCL_HOST_DEVICE constexpr int arch_to_min_bytes_in_flight(int sm_arch)
|
|
|
282
282
|
return 12 * 1024; // V100 and below
|
|
283
283
|
}
|
|
284
284
|
|
|
285
|
-
template <typename
|
|
286
|
-
_CCCL_HOST_DEVICE constexpr bool
|
|
285
|
+
template <typename H, typename... Ts>
|
|
286
|
+
_CCCL_HOST_DEVICE constexpr bool all_nonzero_equal(H head, Ts... values)
|
|
287
287
|
{
|
|
288
|
-
|
|
288
|
+
size_t first = 0;
|
|
289
|
+
for (size_t v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
|
|
290
|
+
{
|
|
291
|
+
if (v == 0)
|
|
292
|
+
{
|
|
293
|
+
continue;
|
|
294
|
+
}
|
|
295
|
+
if (first == 0)
|
|
296
|
+
{
|
|
297
|
+
first = v;
|
|
298
|
+
}
|
|
299
|
+
else if (v != first)
|
|
300
|
+
{
|
|
301
|
+
return false;
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
return true;
|
|
289
305
|
}
|
|
290
306
|
|
|
291
|
-
_CCCL_HOST_DEVICE constexpr bool
|
|
307
|
+
_CCCL_HOST_DEVICE constexpr bool all_nonzero_equal()
|
|
292
308
|
{
|
|
293
309
|
return true;
|
|
294
310
|
}
|
|
295
311
|
|
|
296
|
-
template <typename
|
|
297
|
-
_CCCL_HOST_DEVICE constexpr auto
|
|
312
|
+
template <typename H, typename... Ts>
|
|
313
|
+
_CCCL_HOST_DEVICE constexpr auto first_nonzero_value(H head, Ts... values)
|
|
298
314
|
{
|
|
299
|
-
|
|
315
|
+
for (auto v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...})
|
|
316
|
+
{
|
|
317
|
+
if (v != 0)
|
|
318
|
+
{
|
|
319
|
+
return v;
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
// we only reach here when none of the inputs are contiguous and the output has a void value type
|
|
323
|
+
return H{1};
|
|
300
324
|
}
|
|
301
325
|
|
|
302
326
|
template <typename T>
|
|
@@ -336,25 +360,36 @@ struct policy_hub<RequiresStableAddress,
|
|
|
336
360
|
(THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn> && ...);
|
|
337
361
|
static constexpr bool all_input_values_trivially_reloc =
|
|
338
362
|
(THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>> && ...);
|
|
339
|
-
static constexpr bool
|
|
363
|
+
static constexpr bool can_memcpy_all_inputs = all_inputs_contiguous && all_input_values_trivially_reloc;
|
|
364
|
+
// the vectorized kernel supports mixing contiguous and non-contiguous iterators
|
|
365
|
+
static constexpr bool can_memcpy_contiguous_inputs =
|
|
366
|
+
((!THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>
|
|
367
|
+
|| THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>>)
|
|
368
|
+
&& ...);
|
|
340
369
|
|
|
341
370
|
// for vectorized policy:
|
|
342
|
-
static constexpr bool
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
static constexpr int
|
|
346
|
-
|
|
371
|
+
static constexpr bool all_contiguous_input_values_same_size = all_nonzero_equal(
|
|
372
|
+
(sizeof(it_value_t<RandomAccessIteratorsIn>)
|
|
373
|
+
* THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...);
|
|
374
|
+
static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+
|
|
375
|
+
// find the value type size of the first contiguous iterator. if there are no contiguous inputs, we take the size of the output
|
|
376
|
+
// value type
|
|
377
|
+
static constexpr int contiguous_value_type_size = first_nonzero_value(
|
|
378
|
+
(int{sizeof(it_value_t<RandomAccessIteratorsIn>)}
|
|
379
|
+
* THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...,
|
|
380
|
+
int{size_of<it_value_t<RandomAccessIteratorOut>>});
|
|
347
381
|
static constexpr bool value_type_divides_load_store_size =
|
|
348
|
-
load_store_word_size %
|
|
382
|
+
load_store_word_size % contiguous_value_type_size == 0; // implicitly checks that value_type_size <=
|
|
383
|
+
// load_store_word_size
|
|
349
384
|
static constexpr int target_bytes_per_thread =
|
|
350
385
|
no_input_streams ? 16 /* by experiment on RTX 5090 */ : 32 /* guestimate by gevtushenko for loading */;
|
|
351
386
|
static constexpr int items_per_thread_vec =
|
|
352
|
-
::cuda::round_up(target_bytes_per_thread, load_store_word_size) /
|
|
387
|
+
::cuda::round_up(target_bytes_per_thread, load_store_word_size) / contiguous_value_type_size;
|
|
353
388
|
using default_vectorized_policy_t = vectorized_policy_t<256, items_per_thread_vec, load_store_word_size>;
|
|
354
389
|
|
|
355
390
|
static constexpr bool fallback_to_prefetch =
|
|
356
|
-
RequiresStableAddress || !
|
|
357
|
-
|| !DenseOutput;
|
|
391
|
+
RequiresStableAddress || !can_memcpy_contiguous_inputs || !all_contiguous_input_values_same_size
|
|
392
|
+
|| !value_type_divides_load_store_size || !DenseOutput;
|
|
358
393
|
|
|
359
394
|
// TODO(bgruber): consider a separate kernel for just filling
|
|
360
395
|
|
|
@@ -380,7 +415,7 @@ struct policy_hub<RequiresStableAddress,
|
|
|
380
415
|
block_threads* async_policy::min_items_per_thread,
|
|
381
416
|
ldgsts_size_and_align)
|
|
382
417
|
> int{max_smem_per_block};
|
|
383
|
-
static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams;
|
|
418
|
+
static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams || !can_memcpy_all_inputs;
|
|
384
419
|
|
|
385
420
|
public:
|
|
386
421
|
static constexpr int min_bif = arch_to_min_bytes_in_flight(800);
|
|
@@ -421,7 +456,8 @@ struct policy_hub<RequiresStableAddress,
|
|
|
421
456
|
(((int{sizeof(it_value_t<RandomAccessIteratorsIn>)} * AsyncBlockSize) % max_alignment == 0) && ...);
|
|
422
457
|
static constexpr bool enough_threads_for_peeling = AsyncBlockSize >= alignment; // head and tail bytes
|
|
423
458
|
static constexpr bool fallback_to_vectorized =
|
|
424
|
-
exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams
|
|
459
|
+
exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams
|
|
460
|
+
|| !can_memcpy_all_inputs;
|
|
425
461
|
|
|
426
462
|
public:
|
|
427
463
|
static constexpr int min_bif = arch_to_min_bytes_in_flight(PtxVersion);
|
|
@@ -788,6 +788,16 @@ struct UniqueByKeyPolicyWrapper<StaticPolicyT,
|
|
|
788
788
|
{
|
|
789
789
|
return cub::detail::MakePolicyWrapper(typename StaticPolicyT::UniqueByKeyPolicyT());
|
|
790
790
|
}
|
|
791
|
+
|
|
792
|
+
#if defined(CUB_ENABLE_POLICY_PTX_JSON)
|
|
793
|
+
_CCCL_DEVICE static constexpr auto EncodedPolicy()
|
|
794
|
+
{
|
|
795
|
+
using namespace ptx_json;
|
|
796
|
+
return object<key<"UniqueByKeyPolicyT">() = UniqueByKey().EncodedPolicy(),
|
|
797
|
+
key<"DelayConstructor">() =
|
|
798
|
+
StaticPolicyT::UniqueByKeyPolicyT::detail::delay_constructor_t::EncodedConstructor()>();
|
|
799
|
+
}
|
|
800
|
+
#endif
|
|
791
801
|
};
|
|
792
802
|
|
|
793
803
|
template <typename PolicyT>
|
|
@@ -47,7 +47,6 @@
|
|
|
47
47
|
// for backward compatibility
|
|
48
48
|
#include <cub/util_temporary_storage.cuh>
|
|
49
49
|
|
|
50
|
-
#include <cuda/std/__cuda/ensure_current_device.h> // IWYU pragma: export
|
|
51
50
|
#include <cuda/std/__type_traits/conditional.h>
|
|
52
51
|
#include <cuda/std/__utility/forward.h>
|
|
53
52
|
#include <cuda/std/array>
|
|
@@ -104,7 +103,34 @@ CUB_RUNTIME_FUNCTION inline int CurrentDevice()
|
|
|
104
103
|
|
|
105
104
|
//! @brief RAII helper which saves the current device and switches to the specified device on construction and switches
|
|
106
105
|
//! to the saved device on destruction.
|
|
107
|
-
|
|
106
|
+
class SwitchDevice
|
|
107
|
+
{
|
|
108
|
+
int target_device_;
|
|
109
|
+
int original_device_;
|
|
110
|
+
|
|
111
|
+
public:
|
|
112
|
+
//! @brief Queries the current device and if that is different than @p target_device sets the current device to
|
|
113
|
+
//! @p target_device
|
|
114
|
+
SwitchDevice(const int target_device)
|
|
115
|
+
: target_device_(target_device)
|
|
116
|
+
{
|
|
117
|
+
CubDebug(cudaGetDevice(&original_device_));
|
|
118
|
+
if (original_device_ != target_device_)
|
|
119
|
+
{
|
|
120
|
+
CubDebug(cudaSetDevice(target_device_));
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
//! @brief If the @p original_device was not equal to @p target_device sets the current device back to
|
|
125
|
+
//! @p original_device
|
|
126
|
+
~SwitchDevice()
|
|
127
|
+
{
|
|
128
|
+
if (original_device_ != target_device_)
|
|
129
|
+
{
|
|
130
|
+
CubDebug(cudaSetDevice(original_device_));
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
};
|
|
108
134
|
|
|
109
135
|
# endif // _CCCL_DOXYGEN_INVOKED
|
|
110
136
|
|
|
@@ -684,16 +710,31 @@ struct KernelConfig
|
|
|
684
710
|
return launcher_factory.MaxSmOccupancy(sm_occupancy, kernel_ptr, block_threads);
|
|
685
711
|
}
|
|
686
712
|
};
|
|
687
|
-
|
|
688
713
|
} // namespace detail
|
|
689
714
|
#endif // !_CCCL_COMPILER(NVRTC)
|
|
690
715
|
|
|
716
|
+
namespace detail
|
|
717
|
+
{
|
|
718
|
+
template <typename T>
|
|
719
|
+
struct get_active_policy
|
|
720
|
+
{
|
|
721
|
+
using type = typename T::ActivePolicy;
|
|
722
|
+
};
|
|
723
|
+
} // namespace detail
|
|
724
|
+
|
|
691
725
|
/// Helper for dispatching into a policy chain
|
|
692
726
|
template <int PolicyPtxVersion, typename PolicyT, typename PrevPolicyT>
|
|
693
727
|
struct ChainedPolicy
|
|
694
728
|
{
|
|
729
|
+
private:
|
|
730
|
+
static constexpr bool have_previous_policy = !::cuda::std::is_same_v<PolicyT, PrevPolicyT>;
|
|
731
|
+
|
|
732
|
+
public:
|
|
695
733
|
/// The policy for the active compiler pass
|
|
696
|
-
using ActivePolicy =
|
|
734
|
+
using ActivePolicy =
|
|
735
|
+
typename ::cuda::std::_If<(CUB_PTX_ARCH < PolicyPtxVersion && have_previous_policy),
|
|
736
|
+
detail::get_active_policy<PrevPolicyT>,
|
|
737
|
+
::cuda::std::type_identity<PolicyT>>::type;
|
|
697
738
|
|
|
698
739
|
#if !_CCCL_COMPILER(NVRTC)
|
|
699
740
|
/// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version
|
|
@@ -708,9 +749,12 @@ struct ChainedPolicy
|
|
|
708
749
|
# elif defined(NV_TARGET_SM_INTEGER_LIST)
|
|
709
750
|
return runtime_to_compiletime<10, NV_TARGET_SM_INTEGER_LIST>(device_ptx_version, op);
|
|
710
751
|
# else
|
|
711
|
-
if (
|
|
752
|
+
if constexpr (have_previous_policy)
|
|
712
753
|
{
|
|
713
|
-
|
|
754
|
+
if (device_ptx_version < PolicyPtxVersion)
|
|
755
|
+
{
|
|
756
|
+
return PrevPolicyT::Invoke(device_ptx_version, op);
|
|
757
|
+
}
|
|
714
758
|
}
|
|
715
759
|
return op.template Invoke<PolicyT>();
|
|
716
760
|
# endif
|
|
@@ -738,7 +782,7 @@ private:
|
|
|
738
782
|
template <int DevicePtxVersion, typename FunctorT>
|
|
739
783
|
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op)
|
|
740
784
|
{
|
|
741
|
-
if constexpr (DevicePtxVersion < PolicyPtxVersion)
|
|
785
|
+
if constexpr (DevicePtxVersion < PolicyPtxVersion && have_previous_policy)
|
|
742
786
|
{
|
|
743
787
|
return PrevPolicyT::template invoke_static<DevicePtxVersion>(op);
|
|
744
788
|
}
|
|
@@ -749,34 +793,6 @@ private:
|
|
|
749
793
|
}
|
|
750
794
|
#endif // !_CCCL_COMPILER(NVRTC)
|
|
751
795
|
};
|
|
752
|
-
|
|
753
|
-
/// Helper for dispatching into a policy chain (end-of-chain specialization)
|
|
754
|
-
template <int PolicyPtxVersion, typename PolicyT>
|
|
755
|
-
struct ChainedPolicy<PolicyPtxVersion, PolicyT, PolicyT>
|
|
756
|
-
{
|
|
757
|
-
template <int, typename, typename>
|
|
758
|
-
friend struct ChainedPolicy; // befriend primary template, so it can call invoke_static
|
|
759
|
-
|
|
760
|
-
/// The policy for the active compiler pass
|
|
761
|
-
using ActivePolicy = PolicyT;
|
|
762
|
-
|
|
763
|
-
#if !_CCCL_COMPILER(NVRTC)
|
|
764
|
-
/// Specializes and dispatches op in accordance to the first policy in the chain of adequate PTX version
|
|
765
|
-
template <typename FunctorT>
|
|
766
|
-
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int /*ptx_version*/, FunctorT& op)
|
|
767
|
-
{
|
|
768
|
-
return op.template Invoke<PolicyT>();
|
|
769
|
-
}
|
|
770
|
-
|
|
771
|
-
private:
|
|
772
|
-
template <int, typename FunctorT>
|
|
773
|
-
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op)
|
|
774
|
-
{
|
|
775
|
-
return op.template Invoke<PolicyT>();
|
|
776
|
-
}
|
|
777
|
-
#endif // !_CCCL_COMPILER(NVRTC)
|
|
778
|
-
};
|
|
779
|
-
|
|
780
796
|
CUB_NAMESPACE_END
|
|
781
797
|
|
|
782
798
|
#if _CCCL_HAS_CUDA_COMPILER() && !_CCCL_COMPILER(NVRTC)
|
|
@@ -51,6 +51,7 @@
|
|
|
51
51
|
#include <cuda/__functional/maximum.h>
|
|
52
52
|
#include <cuda/__functional/minimum.h>
|
|
53
53
|
#include <cuda/__ptx/instructions/get_sreg.h>
|
|
54
|
+
#include <cuda/std/__bit/countr.h>
|
|
54
55
|
#include <cuda/std/__functional/operations.h>
|
|
55
56
|
#include <cuda/std/__type_traits/enable_if.h>
|
|
56
57
|
#include <cuda/std/__type_traits/integral_constant.h>
|
|
@@ -701,7 +702,7 @@ struct WarpReduceShfl
|
|
|
701
702
|
_CCCL_DEVICE _CCCL_FORCEINLINE T SegmentedReduce(T input, FlagT flag, ReductionOp reduction_op)
|
|
702
703
|
{
|
|
703
704
|
// Get the start flags for each thread in the warp.
|
|
704
|
-
|
|
705
|
+
unsigned warp_flags = __ballot_sync(member_mask, flag);
|
|
705
706
|
|
|
706
707
|
// Convert to tail-segmented
|
|
707
708
|
if (HEAD_SEGMENTED)
|
|
@@ -722,7 +723,7 @@ struct WarpReduceShfl
|
|
|
722
723
|
warp_flags |= 1u << (LOGICAL_WARP_THREADS - 1);
|
|
723
724
|
|
|
724
725
|
// Find the next set flag
|
|
725
|
-
int last_lane =
|
|
726
|
+
int last_lane = ::cuda::std::countr_zero(warp_flags);
|
|
726
727
|
|
|
727
728
|
T output = input;
|
|
728
729
|
// Template-iterate reduction steps
|
|
@@ -49,6 +49,7 @@
|
|
|
49
49
|
#include <cub/util_type.cuh>
|
|
50
50
|
|
|
51
51
|
#include <cuda/__ptx/instructions/get_sreg.h>
|
|
52
|
+
#include <cuda/std/__bit/countr.h>
|
|
52
53
|
#include <cuda/std/__type_traits/integral_constant.h>
|
|
53
54
|
|
|
54
55
|
CUB_NAMESPACE_BEGIN
|
|
@@ -215,7 +216,7 @@ struct WarpReduceSmem
|
|
|
215
216
|
SegmentedReduce(T input, FlagT flag, ReductionOp reduction_op, ::cuda::std::true_type /*has_ballot*/)
|
|
216
217
|
{
|
|
217
218
|
// Get the start flags for each thread in the warp.
|
|
218
|
-
|
|
219
|
+
unsigned warp_flags = __ballot_sync(member_mask, flag);
|
|
219
220
|
|
|
220
221
|
if (!HEAD_SEGMENTED)
|
|
221
222
|
{
|
|
@@ -232,7 +233,7 @@ struct WarpReduceSmem
|
|
|
232
233
|
}
|
|
233
234
|
|
|
234
235
|
// Find next flag
|
|
235
|
-
int next_flag =
|
|
236
|
+
int next_flag = ::cuda::std::countr_zero(warp_flags);
|
|
236
237
|
|
|
237
238
|
// Clip the next segment at the warp boundary if necessary
|
|
238
239
|
if (LOGICAL_WARP_THREADS != 32)
|