faiss 0.3.0 → 0.3.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/gpu/GpuIndexCagra.h
@@ -0,0 +1,282 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <faiss/IndexIVF.h>
+#include <faiss/gpu/GpuIndex.h>
+#include <faiss/gpu/GpuIndexIVFPQ.h>
+
+namespace faiss {
+struct IndexHNSWCagra;
+}
+
+namespace faiss {
+namespace gpu {
+
+class RaftCagra;
+
+enum class graph_build_algo {
+    /// Use IVF-PQ to build all-neighbors knn graph
+    IVF_PQ,
+    /// Experimental, use NN-Descent to build all-neighbors knn graph
+    NN_DESCENT
+};
+
+/// A type for specifying how PQ codebooks are created.
+enum class codebook_gen { // NOLINT
+    PER_SUBSPACE = 0, // NOLINT
+    PER_CLUSTER = 1, // NOLINT
+};
+
+struct IVFPQBuildCagraConfig {
+    ///
+    /// The number of inverted lists (clusters)
+    ///
+    /// Hint: the number of vectors per cluster (`n_rows/n_lists`) should be
+    /// approximately 1,000 to 10,000.
+
+    uint32_t n_lists = 1024;
+    /// The number of iterations searching for kmeans centers (index building).
+    uint32_t kmeans_n_iters = 20;
+    /// The fraction of data to use during iterative kmeans building.
+    double kmeans_trainset_fraction = 0.5;
+    ///
+    /// The bit length of the vector element after compression by PQ.
+    ///
+    /// Possible values: [4, 5, 6, 7, 8].
+    ///
+    /// Hint: the smaller the 'pq_bits', the smaller the index size and the
+    /// better the search performance, but the lower the recall.
+
+    uint32_t pq_bits = 8;
+    ///
+    /// The dimensionality of the vector after compression by PQ. When zero, an
+    /// optimal value is selected using a heuristic.
+    ///
+    /// NB: `pq_dim /// pq_bits` must be a multiple of 8.
+    ///
+    /// Hint: a smaller 'pq_dim' results in a smaller index size and better
+    /// search performance, but lower recall. If 'pq_bits' is 8, 'pq_dim' can be
+    /// set to any number, but multiple of 8 are desirable for good performance.
+    /// If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8. For good
+    /// performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally,
+    /// 'pq_dim' should be also a divisor of the dataset dim.
+
+    uint32_t pq_dim = 0;
+    /// How PQ codebooks are created.
+    codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE;
+    ///
+    /// Apply a random rotation matrix on the input data and queries even if
+    /// `dim % pq_dim == 0`.
+    ///
+    /// Note: if `dim` is not multiple of `pq_dim`, a random rotation is always
+    /// applied to the input data and queries to transform the working space
+    /// from `dim` to `rot_dim`, which may be slightly larger than the original
+    /// space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`).
+    /// However, this transform is not necessary when `dim` is multiple of
+    /// `pq_dim`
+    /// (`dim == rot_dim`, hence no need in adding "extra" data columns /
+    /// features).
+    ///
+    /// By default, if `dim == rot_dim`, the rotation transform is initialized
+    /// with the identity matrix. When `force_random_rotation == true`, a random
+    /// orthogonal transform matrix is generated regardless of the values of
+    /// `dim` and `pq_dim`.
+
+    bool force_random_rotation = false;
+    ///
+    /// By default, the algorithm allocates more space than necessary for
+    /// individual clusters
+    /// (`list_data`). This allows to amortize the cost of memory allocation and
+    /// reduce the number of data copies during repeated calls to `extend`
+    /// (extending the database).
+    ///
+    /// The alternative is the conservative allocation behavior; when enabled,
+    /// the algorithm always allocates the minimum amount of memory required to
+    /// store the given number of records. Set this flag to `true` if you prefer
+    /// to use as little GPU memory for the database as possible.
+
+    bool conservative_memory_allocation = false;
+};
+
+struct IVFPQSearchCagraConfig {
+    /// The number of clusters to search.
+    uint32_t n_probes = 20;
+    ///
+    /// Data type of look up table to be created dynamically at search time.
+    ///
+    /// Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U]
+    ///
+    /// The use of low-precision types reduces the amount of shared memory
+    /// required at search time, so fast shared memory kernels can be used even
+    /// for datasets with large dimansionality. Note that the recall is slightly
+    /// degraded when low-precision type is selected.
+
+    cudaDataType_t lut_dtype = CUDA_R_32F;
+    ///
+    /// Storage data type for distance/similarity computed at search time.
+    ///
+    /// Possible values: [CUDA_R_16F, CUDA_R_32F]
+    ///
+    /// If the performance limiter at search time is device memory access,
+    /// selecting FP16 will improve performance slightly.
+
+    cudaDataType_t internal_distance_dtype = CUDA_R_32F;
+    ///
+    /// Preferred fraction of SM's unified memory / L1 cache to be used as
+    /// shared memory.
+    ///
+    /// Possible values: [0.0 - 1.0] as a fraction of the
+    /// `sharedMemPerMultiprocessor`.
+    ///
+    /// One wants to increase the carveout to make sure a good GPU occupancy for
+    /// the main search kernel, but not to keep it too high to leave some memory
+    /// to be used as L1 cache. Note, this value is interpreted only as a hint.
+    /// Moreover, a GPU usually allows only a fixed set of cache configurations,
+    /// so the provided value is rounded up to the nearest configuration. Refer
+    /// to the NVIDIA tuning guide for the target GPU architecture.
+    ///
+    /// Note, this is a low-level tuning parameter that can have drastic
+    /// negative effects on the search performance if tweaked incorrectly.
+
+    double preferred_shmem_carveout = 1.0;
+};
+
+struct GpuIndexCagraConfig : public GpuIndexConfig {
+    /// Degree of input graph for pruning.
+    size_t intermediate_graph_degree = 128;
+    /// Degree of output graph.
+    size_t graph_degree = 64;
+    /// ANN algorithm to build knn graph.
+    graph_build_algo build_algo = graph_build_algo::IVF_PQ;
+    /// Number of Iterations to run if building with NN_DESCENT
+    size_t nn_descent_niter = 20;
+
+    IVFPQBuildCagraConfig* ivf_pq_params = nullptr;
+    IVFPQSearchCagraConfig* ivf_pq_search_params = nullptr;
+};
+
+enum class search_algo {
+    /// For large batch sizes.
+    SINGLE_CTA,
+    /// For small batch sizes.
+    MULTI_CTA,
+    MULTI_KERNEL,
+    AUTO
+};
+
+enum class hash_mode { HASH, SMALL, AUTO };
+
+struct SearchParametersCagra : SearchParameters {
+    /// Maximum number of queries to search at the same time (batch size). Auto
+    /// select when 0.
+    size_t max_queries = 0;
+
+    /// Number of intermediate search results retained during the search.
+    ///
+    /// This is the main knob to adjust trade off between accuracy and search
+    /// speed. Higher values improve the search accuracy.
+
+    size_t itopk_size = 64;
+
+    /// Upper limit of search iterations. Auto select when 0.
+    size_t max_iterations = 0;
+
+    // In the following we list additional search parameters for fine tuning.
+    // Reasonable default values are automatically chosen.
+
+    /// Which search implementation to use.
+    search_algo algo = search_algo::AUTO;
+
+    /// Number of threads used to calculate a single distance. 4, 8, 16, or 32.
+
+    size_t team_size = 0;
+
+    /// Number of graph nodes to select as the starting point for the search in
+    /// each iteration. aka search width?
+    size_t search_width = 1;
+    /// Lower limit of search iterations.
+    size_t min_iterations = 0;
+
+    /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
+    size_t thread_block_size = 0;
+    /// Hashmap type. Auto selection when AUTO.
+    hash_mode hashmap_mode = hash_mode::AUTO;
+    /// Lower limit of hashmap bit length. More than 8.
+    size_t hashmap_min_bitlen = 0;
+    /// Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
+    float hashmap_max_fill_rate = 0.5;
+
+    /// Number of iterations of initial random seed node selection. 1 or more.
+
+    uint32_t num_random_samplings = 1;
+    /// Bit mask used for initial random seed node selection.
+    uint64_t seed = 0x128394;
+};
+
+struct GpuIndexCagra : public GpuIndex {
+   public:
+    GpuIndexCagra(
+            GpuResourcesProvider* provider,
+            int dims,
+            faiss::MetricType metric = faiss::METRIC_L2,
+            GpuIndexCagraConfig config = GpuIndexCagraConfig());
+
+    /// Trains CAGRA based on the given vector data
+    void train(idx_t n, const float* x) override;
+
+    /// Initialize ourselves from the given CPU index; will overwrite
+    /// all data in ourselves
+    void copyFrom(const faiss::IndexHNSWCagra* index);
+
+    /// Copy ourselves to the given CPU index; will overwrite all data
+    /// in the index instance
+    void copyTo(faiss::IndexHNSWCagra* index) const;
+
+    void reset() override;
+
+    std::vector<idx_t> get_knngraph() const;
+
+   protected:
+    bool addImplRequiresIDs_() const override;
+
+    void addImpl_(idx_t n, const float* x, const idx_t* ids) override;
+
+    /// Called from GpuIndex for search
+    void searchImpl_(
+            idx_t n,
+            const float* x,
+            int k,
+            float* distances,
+            idx_t* labels,
+            const SearchParameters* search_params) const override;
+
+    /// Our configuration options
+    const GpuIndexCagraConfig cagraConfig_;
+
+    /// Instance that we own; contains the inverted lists
+    std::shared_ptr<RaftCagra> index_;
+};
+
+} // namespace gpu
+} // namespace faiss
data/vendor/faiss/faiss/gpu/GpuIndexFlat.h
@@ -24,15 +24,13 @@ namespace gpu {
 class FlatIndex;
 
 struct GpuIndexFlatConfig : public GpuIndexConfig {
-    inline GpuIndexFlatConfig() : useFloat16(false) {}
-
     /// Whether or not data is stored as float16
-    bool useFloat16;
+    bool ALIGNED(8) useFloat16 = false;
 
     /// Deprecated: no longer used
     /// Previously used to indicate whether internal storage of vectors is
     /// transposed
-    bool storeTransposed;
+    bool storeTransposed = false;
 };
 
 /// Wrapper around the GPU implementation that looks like
@@ -115,6 +113,8 @@ class GpuIndexFlat : public GpuIndex {
     }
 
    protected:
+    void resetIndex_(int dims);
+
     /// Flat index does not require IDs as there is no storage available for
     /// them
    bool addImplRequiresIDs_() const override;
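Dropping the constructor works because both members now carry in-class default initializers, so a default-constructed config is fully initialized (the `ALIGNED(8)` macro comes from faiss/impl/platform_macros.h). A purely illustrative fragment:

    #include <faiss/gpu/GpuIndexFlat.h>

    void configure_flat() {
        faiss::gpu::GpuIndexFlatConfig config; // members default-initialized now
        config.useFloat16 = true;              // opt in to float16 storage
        config.device = 0;                     // inherited from GpuIndexConfig
    }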
data/vendor/faiss/faiss/gpu/GpuIndexIVF.h
@@ -21,13 +21,17 @@ class GpuIndexFlat;
 class IVFBase;
 
 struct GpuIndexIVFConfig : public GpuIndexConfig {
-    inline GpuIndexIVFConfig() : indicesOptions(INDICES_64_BIT) {}
-
     /// Index storage options for the GPU
-    IndicesOptions indicesOptions;
+    IndicesOptions indicesOptions = INDICES_64_BIT;
 
     /// Configuration for the coarse quantizer object
     GpuIndexFlatConfig flatConfig;
+
+    /// This flag controls the CPU fallback logic for coarse quantizer
+    /// component of the index. When set to false (default), the cloner will
+    /// throw an exception for indices not implemented on GPU. When set to
+    /// true, it will fallback to a CPU implementation.
+    bool allowCpuCoarseQuantizer = false;
 };
 
 /// Base class of all GPU IVF index types. This (for now) deliberately does not
@@ -75,10 +79,10 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
     virtual void updateQuantizer() = 0;
 
     /// Returns the number of inverted lists we're managing
-    idx_t getNumLists() const;
+    virtual idx_t getNumLists() const;
 
     /// Returns the number of vectors present in a particular inverted list
-    idx_t getListLength(idx_t listId) const;
+    virtual idx_t getListLength(idx_t listId) const;
 
     /// Return the encoded vector data contained in a particular inverted list,
     /// for debugging purposes.
@@ -86,12 +90,13 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
     /// GPU-side representation.
     /// Otherwise, it is converted to the CPU format.
     /// compliant format, while the native GPU format may differ.
-    std::vector<uint8_t> getListVectorData(
-            idx_t listId, bool gpuFormat = false) const;
+    virtual std::vector<uint8_t> getListVectorData(
+            idx_t listId,
+            bool gpuFormat = false) const;
 
     /// Return the vector indices contained in a particular inverted list, for
     /// debugging purposes.
-    std::vector<idx_t> getListIndices(idx_t listId) const;
+    virtual std::vector<idx_t> getListIndices(idx_t listId) const;
 
     void search_preassigned(
             idx_t n,
@@ -123,7 +128,7 @@ class GpuIndexIVF : public GpuIndex, public IndexIVFInterface {
     int getCurrentNProbe_(const SearchParameters* params) const;
     void verifyIVFSettings_() const;
     bool addImplRequiresIDs_() const override;
-    void trainQuantizer_(idx_t n, const float* x);
+    virtual void trainQuantizer_(idx_t n, const float* x);
 
     /// Called from GpuIndex for add/add_with_ids
     void addImpl_(idx_t n, const float* x, const idx_t* ids) override;
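The new `allowCpuCoarseQuantizer` flag is the user-visible change here. A hedged sketch of opting into the CPU fallback when constructing a GPU IVF index (the constructor arguments follow the vendored header declarations, but treat the exact signature as an assumption):

    #include <faiss/gpu/GpuIndexIVFFlat.h>
    #include <faiss/gpu/StandardGpuResources.h>

    void build_ivf_with_cpu_fallback() {
        faiss::gpu::StandardGpuResources res;

        faiss::gpu::GpuIndexIVFFlatConfig config;
        config.indicesOptions = faiss::gpu::INDICES_64_BIT; // the default, shown explicitly
        // Without this, a coarse quantizer type with no GPU implementation
        // makes the cloner throw; with it, that component stays on the CPU.
        config.allowCpuCoarseQuantizer = true;

        faiss::gpu::GpuIndexIVFFlat index(
                &res, /*dims=*/128, /*nlist=*/1024, faiss::METRIC_L2, config);
    }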
data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h
@@ -8,6 +8,8 @@
 #pragma once
 
 #include <faiss/gpu/GpuIndexIVF.h>
+#include <faiss/impl/ScalarQuantizer.h>
+
 #include <memory>
 
 namespace faiss {
@@ -21,11 +23,9 @@ class IVFFlat;
 class GpuIndexFlat;
 
 struct GpuIndexIVFFlatConfig : public GpuIndexIVFConfig {
-    inline GpuIndexIVFFlatConfig() : interleavedLayout(true) {}
-
     /// Use the alternative memory layout for the IVF lists
     /// (currently the default)
-    bool interleavedLayout;
+    bool interleavedLayout = true;
 };
 
 /// Wrapper around the GPU implementation that looks like
@@ -87,6 +87,23 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
     /// Trains the coarse quantizer based on the given vector data
     void train(idx_t n, const float* x) override;
 
+    void reconstruct_n(idx_t i0, idx_t n, float* out) const override;
+
+   protected:
+    /// Initialize appropriate index
+    void setIndex_(
+            GpuResources* resources,
+            int dim,
+            int nlist,
+            faiss::MetricType metric,
+            float metricArg,
+            bool useResidual,
+            /// Optional ScalarQuantizer
+            faiss::ScalarQuantizer* scalarQ,
+            bool interleavedLayout,
+            IndicesOptions indicesOptions,
+            MemorySpace space);
+
    protected:
     /// Our configuration options
     const GpuIndexIVFFlatConfig ivfFlatConfig_;
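`reconstruct_n` is newly overridden for the GPU IVF-flat index, so stored vectors can be decoded back to host memory. A small sketch (assumes `index` is a trained, populated `GpuIndexIVFFlat`, e.g. the one from the previous fragment):

    #include <faiss/gpu/GpuIndexIVFFlat.h>

    #include <vector>

    void dump_first_vectors(const faiss::gpu::GpuIndexIVFFlat& index) {
        faiss::idx_t n = 10;
        std::vector<float> recons(n * index.d); // d is the dimensionality from faiss::Index
        index.reconstruct_n(0, n, recons.data()); // decode stored vectors 0..9 to host
    }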
data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h
@@ -23,24 +23,19 @@ class GpuIndexFlat;
 class IVFPQ;
 
 struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
-    inline GpuIndexIVFPQConfig()
-            : useFloat16LookupTables(false),
-              usePrecomputedTables(false),
-              interleavedLayout(false),
-              useMMCodeDistance(false) {}
-
     /// Whether or not float16 residual distance tables are used in the
     /// list scanning kernels. When subQuantizers * 2^bitsPerCode >
     /// 16384, this is required.
-    bool useFloat16LookupTables;
+    bool useFloat16LookupTables = false;
 
     /// Whether or not we enable the precomputed table option for
     /// search, which can substantially increase the memory requirement.
-    bool usePrecomputedTables;
+    bool usePrecomputedTables = false;
 
     /// Use the alternative memory layout for the IVF lists
-    /// WARNING: this is a feature under development,
-    bool interleavedLayout;
+    /// WARNING: this is a feature under development, and is only supported with
+    /// RAFT enabled for the index. Do not use if RAFT is not enabled.
+    bool interleavedLayout = false;
 
     /// Use GEMM-backed computation of PQ code distances for the no precomputed
     /// table version of IVFPQ.
@@ -50,7 +45,7 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
     /// Note that MM code distance is enabled automatically if one uses a number
     /// of dimensions per sub-quantizer that is not natively specialized (an odd
     /// number like 7 or so).
-    bool useMMCodeDistance;
+    bool useMMCodeDistance = false;
 };
 
 /// IVFPQ index for the GPU
@@ -139,6 +134,22 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
     ProductQuantizer pq;
 
    protected:
+    /// Initialize appropriate index
+    void setIndex_(
+            GpuResources* resources,
+            int dim,
+            idx_t nlist,
+            faiss::MetricType metric,
+            float metricArg,
+            int numSubQuantizers,
+            int bitsPerSubQuantizer,
+            bool useFloat16LookupTables,
+            bool useMMCodeDistance,
+            bool interleavedLayout,
+            float* pqCentroidData,
+            IndicesOptions indicesOptions,
+            MemorySpace space);
+
     /// Throws errors if configuration settings are improper
     void verifyPQSettings_() const;
 
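For reference, the flags above combine like this (values are illustrative; the comments restate the constraints documented in the header):

    #include <faiss/gpu/GpuIndexIVFPQ.h>

    faiss::gpu::GpuIndexIVFPQConfig make_pq_config() {
        faiss::gpu::GpuIndexIVFPQConfig config;
        // Required once subQuantizers * 2^bitsPerCode > 16384:
        config.useFloat16LookupTables = true;
        // Trades extra memory for faster search:
        config.usePrecomputedTables = true;
        // Documented above as RAFT-only; leave false unless RAFT is enabled:
        config.interleavedLayout = false;
        return config;
    }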
data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h
@@ -18,11 +18,9 @@ class IVFFlat;
 class GpuIndexFlat;
 
 struct GpuIndexIVFScalarQuantizerConfig : public GpuIndexIVFConfig {
-    inline GpuIndexIVFScalarQuantizerConfig() : interleavedLayout(true) {}
-
     /// Use the alternative memory layout for the IVF lists
     /// (currently the default)
-    bool interleavedLayout;
+    bool interleavedLayout = true;
 };
 
 /// Wrapper around the GPU implementation that looks like
data/vendor/faiss/faiss/gpu/GpuResources.cpp
@@ -4,6 +4,21 @@
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
  */
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 #include <faiss/gpu/GpuResources.h>
 #include <faiss/gpu/utils/DeviceUtils.h>
@@ -143,7 +158,7 @@ GpuMemoryReservation::~GpuMemoryReservation() {
 // GpuResources
 //
 
-GpuResources::~GpuResources()
+GpuResources::~GpuResources() = default;
 
 cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
     return getBlasHandle(getCurrentDevice());
@@ -153,6 +168,12 @@ cudaStream_t GpuResources::getDefaultStreamCurrentDevice() {
     return getDefaultStream(getCurrentDevice());
 }
 
+#if defined USE_NVIDIA_RAFT
+raft::device_resources& GpuResources::getRaftHandleCurrentDevice() {
+    return getRaftHandle(getCurrentDevice());
+}
+#endif
+
 std::vector<cudaStream_t> GpuResources::getAlternateStreamsCurrentDevice() {
     return getAlternateStreams(getCurrentDevice());
 }
@@ -182,7 +203,7 @@ size_t GpuResources::getTempMemoryAvailableCurrentDevice() const {
 // GpuResourcesProvider
 //
 
-GpuResourcesProvider::~GpuResourcesProvider()
+GpuResourcesProvider::~GpuResourcesProvider() = default;
 
 //
 // GpuResourcesProviderFromResourceInstance
@@ -192,7 +213,7 @@ GpuResourcesProviderFromInstance::GpuResourcesProviderFromInstance(
         std::shared_ptr<GpuResources> p)
     : res_(p) {}
 
-GpuResourcesProviderFromInstance::~GpuResourcesProviderFromInstance()
+GpuResourcesProviderFromInstance::~GpuResourcesProviderFromInstance() = default;
 
 std::shared_ptr<GpuResources> GpuResourcesProviderFromInstance::getResources() {
     return res_;
data/vendor/faiss/faiss/gpu/GpuResources.h
@@ -4,16 +4,37 @@
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
  */
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 #pragma once
 
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
 #include <faiss/impl/FaissAssert.h>
+
 #include <memory>
 #include <utility>
 #include <vector>
 
+#if defined USE_NVIDIA_RAFT
+#include <raft/core/device_resources.hpp>
+#include <rmm/mr/device/device_memory_resource.hpp>
+#endif
+
 namespace faiss {
 namespace gpu {
 
@@ -82,11 +103,7 @@ std::string memorySpaceToString(MemorySpace s);
 
 /// Information on what/where an allocation is
 struct AllocInfo {
-    inline AllocInfo()
-            : type(AllocType::Other),
-              device(0),
-              space(MemorySpace::Device),
-              stream(nullptr) {}
+    inline AllocInfo() {}
 
     inline AllocInfo(AllocType at, int dev, MemorySpace sp, cudaStream_t st)
             : type(at), device(dev), space(sp), stream(st) {}
@@ -95,13 +112,13 @@ struct AllocInfo {
     std::string toString() const;
 
     /// The internal category of the allocation
-    AllocType type;
+    AllocType type = AllocType::Other;
 
     /// The device on which the allocation is happening
-    int device;
+    int device = 0;
 
     /// The memory space of the allocation
-    MemorySpace space;
+    MemorySpace space = MemorySpace::Device;
 
     /// The stream on which new work on the memory will be ordered (e.g., if a
     /// piece of memory cached and to be returned for this call was last used on
@@ -111,7 +128,7 @@ struct AllocInfo {
     ///
     /// The memory manager guarantees that the returned memory is free to use
     /// without data races on this stream specified.
-    cudaStream_t stream;
+    cudaStream_t stream = nullptr;
 };
 
 /// Create an AllocInfo for the current device with MemorySpace::Device
@@ -125,7 +142,7 @@ AllocInfo makeSpaceAlloc(AllocType at, MemorySpace sp, cudaStream_t st);
 
 /// Information on what/where an allocation is, along with how big it should be
 struct AllocRequest : public AllocInfo {
-    inline AllocRequest()
+    inline AllocRequest() {}
 
     inline AllocRequest(const AllocInfo& info, size_t sz)
             : AllocInfo(info), size(sz) {}
@@ -142,7 +159,11 @@ struct AllocRequest : public AllocInfo {
     std::string toString() const;
 
     /// The size in bytes of the allocation
-    size_t size;
+    size_t size = 0;
+
+#if defined USE_NVIDIA_RAFT
+    rmm::mr::device_memory_resource* mr = nullptr;
+#endif
 };
 
 /// A RAII object that manages a temporary memory request
@@ -190,6 +211,13 @@ class GpuResources {
     /// given device
     virtual cudaStream_t getDefaultStream(int device) = 0;
 
+#if defined USE_NVIDIA_RAFT
+    /// Returns the raft handle for the given device which can be used to
+    /// make calls to other raft primitives.
+    virtual raft::device_resources& getRaftHandle(int device) = 0;
+    raft::device_resources& getRaftHandleCurrentDevice();
+#endif
+
     /// Overrides the default stream for a device to the user-supplied stream.
     /// The resources object does not own this stream (i.e., it will not destroy
    /// it).