faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -24,18 +24,17 @@ namespace {
|
|
|
24
24
|
|
|
25
25
|
typedef Index::idx_t idx_t;
|
|
26
26
|
|
|
27
|
-
|
|
28
27
|
// add translation to all valid labels
|
|
29
|
-
void translate_labels
|
|
30
|
-
|
|
31
|
-
|
|
28
|
+
void translate_labels(long n, idx_t* labels, long translation) {
|
|
29
|
+
if (translation == 0)
|
|
30
|
+
return;
|
|
32
31
|
for (long i = 0; i < n; i++) {
|
|
33
|
-
if(labels[i] < 0)
|
|
32
|
+
if (labels[i] < 0)
|
|
33
|
+
continue;
|
|
34
34
|
labels[i] += translation;
|
|
35
35
|
}
|
|
36
36
|
}
|
|
37
37
|
|
|
38
|
-
|
|
39
38
|
/** merge result tables from several shards.
|
|
40
39
|
* @param all_distances size nshard * n * k
|
|
41
40
|
* @param all_labels idem
|
|
@@ -43,296 +42,313 @@ void translate_labels (long n, idx_t *labels, long translation)
|
|
|
43
42
|
*/
|
|
44
43
|
|
|
45
44
|
template <class IndexClass, class C>
|
|
46
|
-
void
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
45
|
+
void merge_tables(
|
|
46
|
+
long n,
|
|
47
|
+
long k,
|
|
48
|
+
long nshard,
|
|
49
|
+
typename IndexClass::distance_t* distances,
|
|
50
|
+
idx_t* labels,
|
|
51
|
+
const std::vector<typename IndexClass::distance_t>& all_distances,
|
|
52
|
+
const std::vector<idx_t>& all_labels,
|
|
53
|
+
const std::vector<long>& translations) {
|
|
54
|
+
if (k == 0) {
|
|
55
|
+
return;
|
|
56
|
+
}
|
|
57
|
+
using distance_t = typename IndexClass::distance_t;
|
|
58
|
+
|
|
59
|
+
long stride = n * k;
|
|
59
60
|
#pragma omp parallel
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
61
|
+
{
|
|
62
|
+
std::vector<int> buf(2 * nshard);
|
|
63
|
+
int* pointer = buf.data();
|
|
64
|
+
int* shard_ids = pointer + nshard;
|
|
65
|
+
std::vector<distance_t> buf2(nshard);
|
|
66
|
+
distance_t* heap_vals = buf2.data();
|
|
66
67
|
#pragma omp for
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
68
|
+
for (long i = 0; i < n; i++) {
|
|
69
|
+
// the heap maps values to the shard where they are
|
|
70
|
+
// produced.
|
|
71
|
+
const distance_t* D_in = all_distances.data() + i * k;
|
|
72
|
+
const idx_t* I_in = all_labels.data() + i * k;
|
|
73
|
+
int heap_size = 0;
|
|
74
|
+
|
|
75
|
+
for (long s = 0; s < nshard; s++) {
|
|
76
|
+
pointer[s] = 0;
|
|
77
|
+
if (I_in[stride * s] >= 0) {
|
|
78
|
+
heap_push<C>(
|
|
79
|
+
++heap_size,
|
|
80
|
+
heap_vals,
|
|
81
|
+
shard_ids,
|
|
82
|
+
D_in[stride * s],
|
|
83
|
+
s);
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
distance_t* D = distances + i * k;
|
|
88
|
+
idx_t* I = labels + i * k;
|
|
89
|
+
|
|
90
|
+
for (int j = 0; j < k; j++) {
|
|
91
|
+
if (heap_size == 0) {
|
|
92
|
+
I[j] = -1;
|
|
93
|
+
D[j] = C::neutral();
|
|
94
|
+
} else {
|
|
95
|
+
// pop best element
|
|
96
|
+
int s = shard_ids[0];
|
|
97
|
+
int& p = pointer[s];
|
|
98
|
+
D[j] = heap_vals[0];
|
|
99
|
+
I[j] = I_in[stride * s + p] + translations[s];
|
|
100
|
+
|
|
101
|
+
heap_pop<C>(heap_size--, heap_vals, shard_ids);
|
|
102
|
+
p++;
|
|
103
|
+
if (p < k && I_in[stride * s + p] >= 0) {
|
|
104
|
+
heap_push<C>(
|
|
105
|
+
++heap_size,
|
|
106
|
+
heap_vals,
|
|
107
|
+
shard_ids,
|
|
108
|
+
D_in[stride * s + p],
|
|
109
|
+
s);
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
}
|
|
102
113
|
}
|
|
103
|
-
}
|
|
104
114
|
}
|
|
105
|
-
}
|
|
106
115
|
}
|
|
107
116
|
|
|
108
117
|
} // anonymous namespace
|
|
109
118
|
|
|
110
119
|
template <typename IndexT>
|
|
111
|
-
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
}
|
|
120
|
+
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
121
|
+
idx_t d,
|
|
122
|
+
bool threaded,
|
|
123
|
+
bool successive_ids)
|
|
124
|
+
: ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
|
|
117
125
|
|
|
118
126
|
template <typename IndexT>
|
|
119
|
-
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
}
|
|
127
|
+
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
128
|
+
int d,
|
|
129
|
+
bool threaded,
|
|
130
|
+
bool successive_ids)
|
|
131
|
+
: ThreadedIndex<IndexT>(d, threaded), successive_ids(successive_ids) {}
|
|
125
132
|
|
|
126
133
|
template <typename IndexT>
|
|
127
|
-
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
}
|
|
134
|
+
IndexShardsTemplate<IndexT>::IndexShardsTemplate(
|
|
135
|
+
bool threaded,
|
|
136
|
+
bool successive_ids)
|
|
137
|
+
: ThreadedIndex<IndexT>(threaded), successive_ids(successive_ids) {}
|
|
132
138
|
|
|
133
139
|
template <typename IndexT>
|
|
134
|
-
void
|
|
135
|
-
|
|
136
|
-
syncWithSubIndexes();
|
|
140
|
+
void IndexShardsTemplate<IndexT>::onAfterAddIndex(IndexT* index /* unused */) {
|
|
141
|
+
syncWithSubIndexes();
|
|
137
142
|
}
|
|
138
143
|
|
|
139
144
|
template <typename IndexT>
|
|
140
|
-
void
|
|
141
|
-
|
|
142
|
-
|
|
145
|
+
void IndexShardsTemplate<IndexT>::onAfterRemoveIndex(
|
|
146
|
+
IndexT* index /* unused */) {
|
|
147
|
+
syncWithSubIndexes();
|
|
143
148
|
}
|
|
144
149
|
|
|
145
150
|
// FIXME: assumes that nothing is currently running on the sub-indexes, which is
|
|
146
151
|
// true with the normal API, but should use the runOnIndex API instead
|
|
147
152
|
template <typename IndexT>
|
|
148
|
-
void
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
}
|
|
153
|
+
void IndexShardsTemplate<IndexT>::syncWithSubIndexes() {
|
|
154
|
+
if (!this->count()) {
|
|
155
|
+
this->is_trained = false;
|
|
156
|
+
this->ntotal = 0;
|
|
157
|
+
|
|
158
|
+
return;
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
auto firstIndex = this->at(0);
|
|
162
|
+
this->metric_type = firstIndex->metric_type;
|
|
163
|
+
this->is_trained = firstIndex->is_trained;
|
|
164
|
+
this->ntotal = firstIndex->ntotal;
|
|
165
|
+
|
|
166
|
+
for (int i = 1; i < this->count(); ++i) {
|
|
167
|
+
auto index = this->at(i);
|
|
168
|
+
FAISS_THROW_IF_NOT(this->metric_type == index->metric_type);
|
|
169
|
+
FAISS_THROW_IF_NOT(this->d == index->d);
|
|
170
|
+
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
|
|
171
|
+
|
|
172
|
+
this->ntotal += index->ntotal;
|
|
173
|
+
}
|
|
170
174
|
}
|
|
171
175
|
|
|
172
176
|
// No metric_type for IndexBinary
|
|
173
177
|
template <>
|
|
174
|
-
void
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
}
|
|
178
|
+
void IndexShardsTemplate<IndexBinary>::syncWithSubIndexes() {
|
|
179
|
+
if (!this->count()) {
|
|
180
|
+
this->is_trained = false;
|
|
181
|
+
this->ntotal = 0;
|
|
182
|
+
|
|
183
|
+
return;
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
auto firstIndex = this->at(0);
|
|
187
|
+
this->is_trained = firstIndex->is_trained;
|
|
188
|
+
this->ntotal = firstIndex->ntotal;
|
|
189
|
+
|
|
190
|
+
for (int i = 1; i < this->count(); ++i) {
|
|
191
|
+
auto index = this->at(i);
|
|
192
|
+
FAISS_THROW_IF_NOT(this->d == index->d);
|
|
193
|
+
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
|
|
194
|
+
|
|
195
|
+
this->ntotal += index->ntotal;
|
|
196
|
+
}
|
|
194
197
|
}
|
|
195
198
|
|
|
196
199
|
template <typename IndexT>
|
|
197
|
-
void
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
if (index->verbose) {
|
|
209
|
-
printf("end train shard %d\n", no);
|
|
210
|
-
}
|
|
200
|
+
void IndexShardsTemplate<IndexT>::train(idx_t n, const component_t* x) {
|
|
201
|
+
auto fn = [n, x](int no, IndexT* index) {
|
|
202
|
+
if (index->verbose) {
|
|
203
|
+
printf("begin train shard %d on %" PRId64 " points\n", no, n);
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
index->train(n, x);
|
|
207
|
+
|
|
208
|
+
if (index->verbose) {
|
|
209
|
+
printf("end train shard %d\n", no);
|
|
210
|
+
}
|
|
211
211
|
};
|
|
212
212
|
|
|
213
|
-
|
|
214
|
-
|
|
213
|
+
this->runOnIndex(fn);
|
|
214
|
+
syncWithSubIndexes();
|
|
215
215
|
}
|
|
216
216
|
|
|
217
217
|
template <typename IndexT>
|
|
218
|
-
void
|
|
219
|
-
|
|
220
|
-
const component_t *x) {
|
|
221
|
-
add_with_ids(n, x, nullptr);
|
|
218
|
+
void IndexShardsTemplate<IndexT>::add(idx_t n, const component_t* x) {
|
|
219
|
+
add_with_ids(n, x, nullptr);
|
|
222
220
|
}
|
|
223
221
|
|
|
224
222
|
template <typename IndexT>
|
|
225
|
-
void
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
idx_t nshard = this->count();
|
|
244
|
-
const idx_t *ids = xids;
|
|
245
|
-
|
|
246
|
-
std::vector<idx_t> aids;
|
|
247
|
-
|
|
248
|
-
if (!ids && !successive_ids) {
|
|
249
|
-
aids.resize(n);
|
|
250
|
-
|
|
251
|
-
for (idx_t i = 0; i < n; i++) {
|
|
252
|
-
aids[i] = this->ntotal + i;
|
|
223
|
+
void IndexShardsTemplate<IndexT>::add_with_ids(
|
|
224
|
+
idx_t n,
|
|
225
|
+
const component_t* x,
|
|
226
|
+
const idx_t* xids) {
|
|
227
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
228
|
+
!(successive_ids && xids),
|
|
229
|
+
"It makes no sense to pass in ids and "
|
|
230
|
+
"request them to be shifted");
|
|
231
|
+
|
|
232
|
+
if (successive_ids) {
|
|
233
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
234
|
+
!xids,
|
|
235
|
+
"It makes no sense to pass in ids and "
|
|
236
|
+
"request them to be shifted");
|
|
237
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
238
|
+
this->ntotal == 0,
|
|
239
|
+
"when adding to IndexShards with sucessive_ids, "
|
|
240
|
+
"only add() in a single pass is supported");
|
|
253
241
|
}
|
|
254
242
|
|
|
255
|
-
|
|
256
|
-
|
|
243
|
+
idx_t nshard = this->count();
|
|
244
|
+
const idx_t* ids = xids;
|
|
245
|
+
|
|
246
|
+
std::vector<idx_t> aids;
|
|
247
|
+
|
|
248
|
+
if (!ids && !successive_ids) {
|
|
249
|
+
aids.resize(n);
|
|
250
|
+
|
|
251
|
+
for (idx_t i = 0; i < n; i++) {
|
|
252
|
+
aids[i] = this->ntotal + i;
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
ids = aids.data();
|
|
256
|
+
}
|
|
257
257
|
|
|
258
|
-
|
|
259
|
-
|
|
258
|
+
size_t components_per_vec =
|
|
259
|
+
sizeof(component_t) == 1 ? (this->d + 7) / 8 : this->d;
|
|
260
260
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
auto x0 = x + i0 * components_per_vec;
|
|
261
|
+
auto fn = [n, ids, x, nshard, components_per_vec](int no, IndexT* index) {
|
|
262
|
+
idx_t i0 = (idx_t)no * n / nshard;
|
|
263
|
+
idx_t i1 = ((idx_t)no + 1) * n / nshard;
|
|
264
|
+
auto x0 = x + i0 * components_per_vec;
|
|
266
265
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
266
|
+
if (index->verbose) {
|
|
267
|
+
printf("begin add shard %d on %" PRId64 " points\n", no, n);
|
|
268
|
+
}
|
|
270
269
|
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
270
|
+
if (ids) {
|
|
271
|
+
index->add_with_ids(i1 - i0, x0, ids + i0);
|
|
272
|
+
} else {
|
|
273
|
+
index->add(i1 - i0, x0);
|
|
274
|
+
}
|
|
276
275
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
276
|
+
if (index->verbose) {
|
|
277
|
+
printf("end add shard %d on %" PRId64 " points\n", no, i1 - i0);
|
|
278
|
+
}
|
|
280
279
|
};
|
|
281
280
|
|
|
282
|
-
|
|
283
|
-
|
|
281
|
+
this->runOnIndex(fn);
|
|
282
|
+
syncWithSubIndexes();
|
|
284
283
|
}
|
|
285
284
|
|
|
286
285
|
template <typename IndexT>
|
|
287
|
-
void
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
286
|
+
void IndexShardsTemplate<IndexT>::search(
|
|
287
|
+
idx_t n,
|
|
288
|
+
const component_t* x,
|
|
289
|
+
idx_t k,
|
|
290
|
+
distance_t* distances,
|
|
291
|
+
idx_t* labels) const {
|
|
292
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
293
|
+
|
|
294
|
+
long nshard = this->count();
|
|
295
|
+
|
|
296
|
+
std::vector<distance_t> all_distances(nshard * k * n);
|
|
297
|
+
std::vector<idx_t> all_labels(nshard * k * n);
|
|
298
|
+
|
|
299
|
+
auto fn = [n, k, x, &all_distances, &all_labels](
|
|
300
|
+
int no, const IndexT* index) {
|
|
301
|
+
if (index->verbose) {
|
|
302
|
+
printf("begin query shard %d on %" PRId64 " points\n", no, n);
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
index->search(
|
|
306
|
+
n,
|
|
307
|
+
x,
|
|
308
|
+
k,
|
|
309
|
+
all_distances.data() + no * k * n,
|
|
310
|
+
all_labels.data() + no * k * n);
|
|
311
|
+
|
|
312
|
+
if (index->verbose) {
|
|
313
|
+
printf("end query shard %d\n", no);
|
|
314
|
+
}
|
|
311
315
|
};
|
|
312
316
|
|
|
313
|
-
|
|
317
|
+
this->runOnIndex(fn);
|
|
318
|
+
|
|
319
|
+
std::vector<long> translations(nshard, 0);
|
|
314
320
|
|
|
315
|
-
|
|
321
|
+
// Because we just called runOnIndex above, it is safe to access the
|
|
322
|
+
// sub-index ntotal here
|
|
323
|
+
if (successive_ids) {
|
|
324
|
+
translations[0] = 0;
|
|
316
325
|
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
326
|
+
for (int s = 0; s + 1 < nshard; s++) {
|
|
327
|
+
translations[s + 1] = translations[s] + this->at(s)->ntotal;
|
|
328
|
+
}
|
|
329
|
+
}
|
|
321
330
|
|
|
322
|
-
|
|
323
|
-
|
|
331
|
+
if (this->metric_type == METRIC_L2) {
|
|
332
|
+
merge_tables<IndexT, CMin<distance_t, int>>(
|
|
333
|
+
n,
|
|
334
|
+
k,
|
|
335
|
+
nshard,
|
|
336
|
+
distances,
|
|
337
|
+
labels,
|
|
338
|
+
all_distances,
|
|
339
|
+
all_labels,
|
|
340
|
+
translations);
|
|
341
|
+
} else {
|
|
342
|
+
merge_tables<IndexT, CMax<distance_t, int>>(
|
|
343
|
+
n,
|
|
344
|
+
k,
|
|
345
|
+
nshard,
|
|
346
|
+
distances,
|
|
347
|
+
labels,
|
|
348
|
+
all_distances,
|
|
349
|
+
all_labels,
|
|
350
|
+
translations);
|
|
324
351
|
}
|
|
325
|
-
}
|
|
326
|
-
|
|
327
|
-
if (this->metric_type == METRIC_L2) {
|
|
328
|
-
merge_tables<IndexT, CMin<distance_t, int>>(
|
|
329
|
-
n, k, nshard, distances, labels,
|
|
330
|
-
all_distances, all_labels, translations);
|
|
331
|
-
} else {
|
|
332
|
-
merge_tables<IndexT, CMax<distance_t, int>>(
|
|
333
|
-
n, k, nshard, distances, labels,
|
|
334
|
-
all_distances, all_labels, translations);
|
|
335
|
-
}
|
|
336
352
|
}
|
|
337
353
|
|
|
338
354
|
// explicit instanciations
|