faiss 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
@@ -559,15 +559,13 @@ struct simd16uint16 {
|
|
559
559
|
}
|
560
560
|
|
561
561
|
// Checks whether the other holds exactly the same bytes.
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
return equal0 && equal1;
|
562
|
+
template <typename T>
|
563
|
+
bool is_same_as(T other) const {
|
564
|
+
const auto o = detail::simdlib::reinterpret_u16(other.data);
|
565
|
+
const auto equals = detail::simdlib::binary_func(data, o)
|
566
|
+
.template call<&vceqq_u16>();
|
567
|
+
const auto equal = vandq_u16(equals.val[0], equals.val[1]);
|
568
|
+
return vminvq_u16(equal) == 0xffffu;
|
571
569
|
}
|
572
570
|
|
573
571
|
simd16uint16 operator~() const {
|
@@ -689,13 +687,12 @@ inline void cmplt_min_max_fast(
|
|
689
687
|
simd16uint16& minIndices,
|
690
688
|
simd16uint16& maxValues,
|
691
689
|
simd16uint16& maxIndices) {
|
692
|
-
const uint16x8x2_t comparison =
|
693
|
-
|
694
|
-
|
690
|
+
const uint16x8x2_t comparison =
|
691
|
+
detail::simdlib::binary_func(
|
692
|
+
candidateValues.data, currentValues.data)
|
693
|
+
.call<&vcltq_u16>();
|
695
694
|
|
696
|
-
minValues
|
697
|
-
vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
|
698
|
-
vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
|
695
|
+
minValues = min(candidateValues, currentValues);
|
699
696
|
minIndices.data = uint16x8x2_t{
|
700
697
|
vbslq_u16(
|
701
698
|
comparison.val[0],
|
@@ -706,9 +703,7 @@ inline void cmplt_min_max_fast(
|
|
706
703
|
candidateIndices.data.val[1],
|
707
704
|
currentIndices.data.val[1])};
|
708
705
|
|
709
|
-
maxValues
|
710
|
-
vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
|
711
|
-
vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
|
706
|
+
maxValues = max(candidateValues, currentValues);
|
712
707
|
maxIndices.data = uint16x8x2_t{
|
713
708
|
vbslq_u16(
|
714
709
|
comparison.val[0],
|
@@ -869,13 +864,13 @@ struct simd32uint8 {
|
|
869
864
|
}
|
870
865
|
|
871
866
|
// Checks whether the other holds exactly the same bytes.
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
const
|
876
|
-
|
877
|
-
|
878
|
-
return
|
867
|
+
template <typename T>
|
868
|
+
bool is_same_as(T other) const {
|
869
|
+
const auto o = detail::simdlib::reinterpret_u8(other.data);
|
870
|
+
const auto equals = detail::simdlib::binary_func(data, o)
|
871
|
+
.template call<&vceqq_u8>();
|
872
|
+
const auto equal = vandq_u8(equals.val[0], equals.val[1]);
|
873
|
+
return vminvq_u8(equal) == 0xffu;
|
879
874
|
}
|
880
875
|
};
|
881
876
|
|
@@ -960,27 +955,28 @@ struct simd8uint32 {
|
|
960
955
|
return *this;
|
961
956
|
}
|
962
957
|
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
|
967
|
-
return vminvq_u32(equal) == 0xffffffff;
|
958
|
+
simd8uint32 operator==(simd8uint32 other) const {
    // Lane-wise comparison: each 32-bit lane is all ones where the inputs
    // are equal and zero elsewhere.
    const auto lane_eq =
            detail::simdlib::binary_func(data, other.data).call<&vceqq_u32>();
    return simd8uint32{lane_eq};
}
|
969
962
|
|
970
|
-
|
971
|
-
return
|
963
|
+
simd8uint32 operator~() const {
    // Bitwise NOT of both 4-lane halves.
    const auto inverted =
            detail::simdlib::unary_func(data).call<&vmvnq_u32>();
    return simd8uint32{inverted};
}
|
973
967
|
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
(vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
|
978
|
-
0xffffffff);
|
979
|
-
const bool equal1 =
|
980
|
-
(vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
|
981
|
-
0xffffffff);
|
968
|
+
simd8uint32 operator!=(simd8uint32 other) const {
    // Complement of the lane-wise equality mask.
    const simd8uint32 eq = (*this == other);
    return ~eq;
}
|
982
971
|
|
983
|
-
|
972
|
+
// Checks whether the other holds exactly the same bytes.
|
973
|
+
template <typename T>
|
974
|
+
bool is_same_as(T other) const {
|
975
|
+
const auto o = detail::simdlib::reinterpret_u32(other.data);
|
976
|
+
const auto equals = detail::simdlib::binary_func(data, o)
|
977
|
+
.template call<&vceqq_u32>();
|
978
|
+
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
|
979
|
+
return vminvq_u32(equal) == 0xffffffffu;
|
984
980
|
}
|
985
981
|
|
986
982
|
void clear() {
|
@@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast(
|
|
1053
1049
|
simd8uint32& minIndices,
|
1054
1050
|
simd8uint32& maxValues,
|
1055
1051
|
simd8uint32& maxIndices) {
|
1056
|
-
const uint32x4x2_t comparison =
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1052
|
+
const uint32x4x2_t comparison =
|
1053
|
+
detail::simdlib::binary_func(
|
1054
|
+
candidateValues.data, currentValues.data)
|
1055
|
+
.call<&vcltq_u32>();
|
1056
|
+
|
1057
|
+
minValues.data = detail::simdlib::binary_func(
|
1058
|
+
candidateValues.data, currentValues.data)
|
1059
|
+
.call<&vminq_u32>();
|
1063
1060
|
minIndices.data = uint32x4x2_t{
|
1064
1061
|
vbslq_u32(
|
1065
1062
|
comparison.val[0],
|
@@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast(
|
|
1070
1067
|
candidateIndices.data.val[1],
|
1071
1068
|
currentIndices.data.val[1])};
|
1072
1069
|
|
1073
|
-
maxValues.data =
|
1074
|
-
|
1075
|
-
|
1070
|
+
maxValues.data = detail::simdlib::binary_func(
|
1071
|
+
candidateValues.data, currentValues.data)
|
1072
|
+
.call<&vmaxq_u32>();
|
1076
1073
|
maxIndices.data = uint32x4x2_t{
|
1077
1074
|
vbslq_u32(
|
1078
1075
|
comparison.val[0],
|
@@ -1167,28 +1164,25 @@ struct simd8float32 {
|
|
1167
1164
|
return *this;
|
1168
1165
|
}
|
1169
1166
|
|
1170
|
-
|
1171
|
-
|
1167
|
+
simd8uint32 operator==(simd8float32 other) const {
    // Lane-wise float comparison producing a uint32 mask per lane (all ones
    // where equal, zero elsewhere).
    const auto lane_eq =
            detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
                    .call<&vceqq_f32>();
    return simd8uint32{lane_eq};
}
|
1177
1172
|
|
1178
|
-
|
1179
|
-
return
|
1173
|
+
simd8uint32 operator!=(simd8float32 other) const {
    // Complement of the lane-wise equality mask.
    const simd8uint32 eq = (*this == other);
    return ~eq;
}
|
1181
1176
|
|
1182
1177
|
// Checks whether the other holds exactly the same bytes.
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
return equal0 && equal1;
|
1178
|
+
template <typename T>
|
1179
|
+
bool is_same_as(T other) const {
|
1180
|
+
const auto o = detail::simdlib::reinterpret_f32(other.data);
|
1181
|
+
const auto equals =
|
1182
|
+
detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
|
1183
|
+
.template call<&vceqq_f32>();
|
1184
|
+
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
|
1185
|
+
return vminvq_u32(equal) == 0xffffffffu;
|
1192
1186
|
}
|
1193
1187
|
|
1194
1188
|
std::string tostring() const {
|
@@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast(
|
|
1302
1296
|
simd8uint32& minIndices,
|
1303
1297
|
simd8float32& maxValues,
|
1304
1298
|
simd8uint32& maxIndices) {
|
1305
|
-
const uint32x4x2_t comparison =
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1299
|
+
const uint32x4x2_t comparison =
|
1300
|
+
detail::simdlib::binary_func<::uint32x4x2_t>(
|
1301
|
+
candidateValues.data, currentValues.data)
|
1302
|
+
.call<&vcltq_f32>();
|
1303
|
+
|
1304
|
+
minValues.data = detail::simdlib::binary_func(
|
1305
|
+
candidateValues.data, currentValues.data)
|
1306
|
+
.call<&vminq_f32>();
|
1312
1307
|
minIndices.data = uint32x4x2_t{
|
1313
1308
|
vbslq_u32(
|
1314
1309
|
comparison.val[0],
|
@@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast(
|
|
1319
1314
|
candidateIndices.data.val[1],
|
1320
1315
|
currentIndices.data.val[1])};
|
1321
1316
|
|
1322
|
-
maxValues.data =
|
1323
|
-
|
1324
|
-
|
1317
|
+
maxValues.data = detail::simdlib::binary_func(
|
1318
|
+
candidateValues.data, currentValues.data)
|
1319
|
+
.call<&vmaxq_f32>();
|
1325
1320
|
maxIndices.data = uint32x4x2_t{
|
1326
1321
|
vbslq_u32(
|
1327
1322
|
comparison.val[0],
|
@@ -123,7 +123,7 @@ void parallel_merge(
|
|
123
123
|
}
|
124
124
|
}
|
125
125
|
|
126
|
-
}
|
126
|
+
} // namespace
|
127
127
|
|
128
128
|
void fvec_argsort(size_t n, const float* vals, size_t* perm) {
|
129
129
|
for (size_t i = 0; i < n; i++) {
|
@@ -544,7 +544,6 @@ void bucket_sort_inplace_parallel(
|
|
544
544
|
|
545
545
|
// in this loop, we write elements collected in the previous round
|
546
546
|
// and collect the elements that are overwritten for the next round
|
547
|
-
size_t tot_written = 0;
|
548
547
|
int round = 0;
|
549
548
|
for (;;) {
|
550
549
|
#pragma omp barrier
|
@@ -554,9 +553,6 @@ void bucket_sort_inplace_parallel(
|
|
554
553
|
n_to_write += to_write_2.lims.back();
|
555
554
|
}
|
556
555
|
|
557
|
-
tot_written += n_to_write;
|
558
|
-
// assert(tot_written <= nval);
|
559
|
-
|
560
556
|
#pragma omp master
|
561
557
|
{
|
562
558
|
if (verbose >= 1) {
|
@@ -689,4 +685,143 @@ void matrix_bucket_sort_inplace(
|
|
689
685
|
}
|
690
686
|
}
|
691
687
|
|
688
|
+
/** Hashtable implementation for int64 -> int64 with external storage
 * implemented for speed and parallel processing.
 */

namespace {

/// Number of bucket bits for a table of 2^log2_capacity slots. A bucket is
/// a contiguous range of slots; hashtable_int64_to_int64_add fills each
/// bucket from a single thread.
///  - log2_capacity < 12:       1 bucket
///  - 12 <= log2_capacity < 20: buckets of 2^12 slots
///  - log2_capacity >= 20:      2^10 buckets
int log2_capacity_to_log2_nbucket(int log2_capacity) {
    if (log2_capacity < 12) {
        return 0;
    }
    if (log2_capacity < 20) {
        return log2_capacity - 12;
    }
    return 10;
}

// https://bigprimes.org/
constexpr int64_t bigprime = 8955327411143;

/// Hash of a key, with |result| < bigprime. Negative keys can produce
/// negative hashes; callers mask the result to a slot index.
inline int64_t hash_function(int64_t x) {
    // Multiply in uint64_t: the signed product overflows (undefined
    // behavior) for |x| > INT64_MAX / 1000003. Converting the wrapped
    // product back to int64_t (two's complement on all supported compilers,
    // guaranteed since C++20) gives the same values the signed multiply
    // produced in practice, so existing tables hash identically.
    const int64_t prod =
            static_cast<int64_t>(static_cast<uint64_t>(x) * 1000003u);
    return prod % bigprime;
}

} // anonymous namespace
|
708
|
+
|
709
|
+
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
    // The table holds 2^log2_capacity (key, value) pairs; -1 in the key
    // position is what add/lookup treat as an empty slot.
    const size_t n_slots = size_t(1) << log2_capacity;
#pragma omp parallel for
    for (int64_t slot = 0; slot < n_slots; slot++) {
        int64_t* entry = tab + 2 * slot;
        entry[0] = -1;
        entry[1] = -1;
    }
}
|
717
|
+
|
718
|
+
void hashtable_int64_to_int64_add(
|
719
|
+
int log2_capacity,
|
720
|
+
int64_t* tab,
|
721
|
+
size_t n,
|
722
|
+
const int64_t* keys,
|
723
|
+
const int64_t* vals) {
|
724
|
+
size_t capacity = (size_t)1 << log2_capacity;
|
725
|
+
std::vector<int64_t> hk(n);
|
726
|
+
std::vector<uint64_t> bucket_no(n);
|
727
|
+
int64_t mask = capacity - 1;
|
728
|
+
int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
|
729
|
+
size_t nbucket = (size_t)1 << log2_nbucket;
|
730
|
+
|
731
|
+
#pragma omp parallel for
|
732
|
+
for (int64_t i = 0; i < n; i++) {
|
733
|
+
hk[i] = hash_function(keys[i]) & mask;
|
734
|
+
bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
|
735
|
+
}
|
736
|
+
|
737
|
+
std::vector<int64_t> lims(nbucket + 1);
|
738
|
+
std::vector<int64_t> perm(n);
|
739
|
+
bucket_sort(
|
740
|
+
n,
|
741
|
+
bucket_no.data(),
|
742
|
+
nbucket,
|
743
|
+
lims.data(),
|
744
|
+
perm.data(),
|
745
|
+
omp_get_max_threads());
|
746
|
+
|
747
|
+
int num_errors = 0;
|
748
|
+
#pragma omp parallel for reduction(+ : num_errors)
|
749
|
+
for (int64_t bucket = 0; bucket < nbucket; bucket++) {
|
750
|
+
size_t k0 = bucket << (log2_capacity - log2_nbucket);
|
751
|
+
size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
|
752
|
+
|
753
|
+
for (size_t i = lims[bucket]; i < lims[bucket + 1]; i++) {
|
754
|
+
int64_t j = perm[i];
|
755
|
+
assert(bucket_no[j] == bucket);
|
756
|
+
assert(hk[j] >= k0 && hk[j] < k1);
|
757
|
+
size_t slot = hk[j];
|
758
|
+
for (;;) {
|
759
|
+
if (tab[slot * 2] == -1) { // found!
|
760
|
+
tab[slot * 2] = keys[j];
|
761
|
+
tab[slot * 2 + 1] = vals[j];
|
762
|
+
break;
|
763
|
+
} else if (tab[slot * 2] == keys[j]) { // overwrite!
|
764
|
+
tab[slot * 2 + 1] = vals[j];
|
765
|
+
break;
|
766
|
+
}
|
767
|
+
slot++;
|
768
|
+
if (slot == k1) {
|
769
|
+
slot = k0;
|
770
|
+
}
|
771
|
+
if (slot == hk[j]) { // no free slot left in bucket
|
772
|
+
num_errors++;
|
773
|
+
break;
|
774
|
+
}
|
775
|
+
}
|
776
|
+
if (num_errors > 0) {
|
777
|
+
break;
|
778
|
+
}
|
779
|
+
}
|
780
|
+
}
|
781
|
+
FAISS_THROW_IF_NOT_MSG(num_errors == 0, "hashtable capacity exhausted");
|
782
|
+
}
|
783
|
+
|
784
|
+
void hashtable_int64_to_int64_lookup(
|
785
|
+
int log2_capacity,
|
786
|
+
const int64_t* tab,
|
787
|
+
size_t n,
|
788
|
+
const int64_t* keys,
|
789
|
+
int64_t* vals) {
|
790
|
+
size_t capacity = (size_t)1 << log2_capacity;
|
791
|
+
std::vector<int64_t> hk(n), bucket_no(n);
|
792
|
+
int64_t mask = capacity - 1;
|
793
|
+
int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
|
794
|
+
|
795
|
+
#pragma omp parallel for
|
796
|
+
for (int64_t i = 0; i < n; i++) {
|
797
|
+
int64_t k = keys[i];
|
798
|
+
int64_t hk = hash_function(k) & mask;
|
799
|
+
size_t slot = hk;
|
800
|
+
|
801
|
+
if (tab[2 * slot] == -1) { // not in table
|
802
|
+
vals[i] = -1;
|
803
|
+
} else if (tab[2 * slot] == k) { // found!
|
804
|
+
vals[i] = tab[2 * slot + 1];
|
805
|
+
} else { // need to search in [k0, k1)
|
806
|
+
size_t bucket = hk >> (log2_capacity - log2_nbucket);
|
807
|
+
size_t k0 = bucket << (log2_capacity - log2_nbucket);
|
808
|
+
size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
|
809
|
+
for (;;) {
|
810
|
+
if (tab[slot * 2] == k) { // found!
|
811
|
+
vals[i] = tab[2 * slot + 1];
|
812
|
+
break;
|
813
|
+
}
|
814
|
+
slot++;
|
815
|
+
if (slot == k1) {
|
816
|
+
slot = k0;
|
817
|
+
}
|
818
|
+
if (slot == hk) { // bucket is full and not found
|
819
|
+
vals[i] = -1;
|
820
|
+
break;
|
821
|
+
}
|
822
|
+
}
|
823
|
+
}
|
824
|
+
}
|
825
|
+
}
|
826
|
+
|
692
827
|
} // namespace faiss
|
@@ -68,4 +68,31 @@ void matrix_bucket_sort_inplace(
|
|
68
68
|
int64_t* lims,
|
69
69
|
int nt = 0);
|
70
70
|
|
71
|
+
/** Hashtable implementation for int64 -> int64 with external storage
 * implemented for fast batch add and lookup.
 *
 * tab is of size 2 * (1 << log2_capacity): slot i stores its key in
 * tab[2 * i] and its value in tab[2 * i + 1]. A key of -1 marks an empty
 * slot, so -1 cannot be used as a key.
 * n is the number of elements to add or search.
 *
 * When the same key is added several times in the same batch, an arbitrary
 * one of its values is kept; across batches, the newer batch overwrites.
 * add raises an exception if the capacity is exhausted.
 * lookup writes -1 for keys that are not in the table.
 */

/// Marks every slot of tab as empty (-1); call this on a fresh table before
/// adding to it.
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab);

/// Inserts or overwrites the n (keys[i], vals[i]) pairs.
void hashtable_int64_to_int64_add(
        int log2_capacity,
        int64_t* tab,
        size_t n,
        const int64_t* keys,
        const int64_t* vals);

/// Writes the value of each of the n keys to vals, or -1 if absent.
void hashtable_int64_to_int64_lookup(
        int log2_capacity,
        const int64_t* tab,
        size_t n,
        const int64_t* keys,
        int64_t* vals);
|
97
|
+
|
71
98
|
} // namespace faiss
|
@@ -28,6 +28,8 @@
|
|
28
28
|
#include <omp.h>
|
29
29
|
|
30
30
|
#include <algorithm>
|
31
|
+
#include <set>
|
32
|
+
#include <type_traits>
|
31
33
|
#include <vector>
|
32
34
|
|
33
35
|
#include <faiss/impl/AuxIndexStructures.h>
|
@@ -101,6 +103,9 @@ int sgemv_(
|
|
101
103
|
|
102
104
|
namespace faiss {
|
103
105
|
|
106
|
+
// Extra option tags that get_compile_options() appends to its result; this
// will be set at load time from GPU Faiss (left empty otherwise).
std::string gpu_compile_options{};
|
108
|
+
|
104
109
|
std::string get_compile_options() {
|
105
110
|
std::string options;
|
106
111
|
|
@@ -110,13 +115,17 @@ std::string get_compile_options() {
|
|
110
115
|
#endif
|
111
116
|
|
112
117
|
#ifdef __AVX2__
|
113
|
-
options += "AVX2";
|
118
|
+
options += "AVX2 ";
|
119
|
+
#elif __AVX512F__
|
120
|
+
options += "AVX512 ";
|
114
121
|
#elif defined(__aarch64__)
|
115
|
-
options += "NEON";
|
122
|
+
options += "NEON ";
|
116
123
|
#else
|
117
|
-
options += "GENERIC";
|
124
|
+
options += "GENERIC ";
|
118
125
|
#endif
|
119
126
|
|
127
|
+
options += gpu_compile_options;
|
128
|
+
|
120
129
|
return options;
|
121
130
|
}
|
122
131
|
|
@@ -423,15 +432,35 @@ void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist) {
|
|
423
432
|
}
|
424
433
|
}
|
425
434
|
|
426
|
-
|
427
|
-
const uint32_t* a = reinterpret_cast<const uint32_t*>(
|
428
|
-
|
435
|
+
/// Order-dependent checksum of n 32-bit words, folded from the last word to
/// the first. Each word's contribution is computed in 32-bit unsigned
/// arithmetic (a[n] is uint32_t) before being added to the 64-bit state.
uint64_t ivec_checksum(size_t n, const int32_t* assigned) {
    const uint32_t* a = reinterpret_cast<const uint32_t*>(assigned);
    uint64_t cs = 112909;
    while (n--) {
        cs = cs * 65713 + a[n] * 1686049;
    }
    return cs;
}

/// Checksum of n bytes: the n / 4 whole 32-bit words go through
/// ivec_checksum, then the n % 4 trailing bytes are folded in order.
uint64_t bvec_checksum(size_t n, const uint8_t* a) {
    // NOTE(review): `a` is read through an int32_t pointer, so this relies
    // on unaligned 32-bit loads being tolerated (bvecs_checksum passes
    // a + i * d for arbitrary d).
    uint64_t cs = ivec_checksum(n / 4, reinterpret_cast<const int32_t*>(a));
    for (size_t i = n / 4 * 4; i < n; i++) {
        // Index the tail byte with i: indexing with n read one byte past
        // the end of the buffer and ignored the trailing bytes entirely.
        cs = cs * 65713 + a[i] * 1686049;
    }
    return cs;
}

/// Computes cs[i] = bvec_checksum(d, a + i * d) for each of the n rows,
/// in parallel when n > 1000.
void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs) {
    // MSVC can't accept unsigned index for #pragma omp parallel for
    // so below codes only accept n <= std::numeric_limits<ssize_t>::max()
    using ssize_t = std::make_signed<std::size_t>::type;
    const ssize_t size = n;
#pragma omp parallel for if (size > 1000)
    for (ssize_t i_ = 0; i_ < size; i_++) {
        const auto i = static_cast<std::size_t>(i_);
        cs[i] = bvec_checksum(d, a + i * d);
    }
}
|
463
|
+
|
435
464
|
const float* fvecs_maybe_subsample(
|
436
465
|
size_t d,
|
437
466
|
size_t* n,
|
@@ -528,4 +557,81 @@ bool check_openmp() {
|
|
528
557
|
return true;
|
529
558
|
}
|
530
559
|
|
560
|
+
namespace {

// Index of the first element of row[0..n) for which `keep` is false, or n
// when every element passes (n is also returned unchanged when n <= 0).
template <typename T, typename Pred>
int64_t leading_run_length(int64_t n, const T* row, Pred keep) {
    for (int64_t pos = 0; pos < n; pos++) {
        if (!keep(row[pos])) {
            return pos;
        }
    }
    return n;
}

// Length of the prefix of row whose entries are strictly below threshold.
template <typename T>
int64_t count_lt(int64_t n, const T* row, T threshold) {
    return leading_run_length(
            n, row, [threshold](const T& v) { return v < threshold; });
}

// Length of the prefix of row whose entries are strictly above threshold.
template <typename T>
int64_t count_gt(int64_t n, const T* row, T threshold) {
    return leading_run_length(
            n, row, [threshold](const T& v) { return v > threshold; });
}

} // namespace
|
583
|
+
|
584
|
+
template <typename T>
|
585
|
+
void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
|
586
|
+
this->L_res = L_res_2;
|
587
|
+
L_res_2[0] = 0;
|
588
|
+
int64_t j = 0;
|
589
|
+
for (int64_t i = 0; i < nq; i++) {
|
590
|
+
int64_t n_in;
|
591
|
+
if (!mask || !mask[i]) {
|
592
|
+
const T* row = D + i * k;
|
593
|
+
n_in = keep_max ? count_gt(k, row, r2) : count_lt(k, row, r2);
|
594
|
+
} else {
|
595
|
+
n_in = lim_remain[j + 1] - lim_remain[j];
|
596
|
+
j++;
|
597
|
+
}
|
598
|
+
L_res_2[i + 1] = n_in; // L_res_2[i] + n_in;
|
599
|
+
}
|
600
|
+
// cumsum
|
601
|
+
for (int64_t i = 0; i < nq; i++) {
|
602
|
+
L_res_2[i + 1] += L_res_2[i];
|
603
|
+
}
|
604
|
+
}
|
605
|
+
|
606
|
+
template <typename T>
|
607
|
+
void CombinerRangeKNN<T>::write_result(T* D_res, int64_t* I_res) {
|
608
|
+
FAISS_THROW_IF_NOT(L_res);
|
609
|
+
int64_t j = 0;
|
610
|
+
for (int64_t i = 0; i < nq; i++) {
|
611
|
+
int64_t n_in = L_res[i + 1] - L_res[i];
|
612
|
+
T* D_row = D_res + L_res[i];
|
613
|
+
int64_t* I_row = I_res + L_res[i];
|
614
|
+
if (!mask || !mask[i]) {
|
615
|
+
memcpy(D_row, D + i * k, n_in * sizeof(*D_row));
|
616
|
+
memcpy(I_row, I + i * k, n_in * sizeof(*I_row));
|
617
|
+
} else {
|
618
|
+
memcpy(D_row, D_remain + lim_remain[j], n_in * sizeof(*D_row));
|
619
|
+
memcpy(I_row, I_remain + lim_remain[j], n_in * sizeof(*I_row));
|
620
|
+
j++;
|
621
|
+
}
|
622
|
+
}
|
623
|
+
}
|
624
|
+
|
625
|
+
// explicit template instantiations
|
626
|
+
template struct CombinerRangeKNN<float>;
|
627
|
+
template struct CombinerRangeKNN<int16_t>;
|
628
|
+
|
629
|
+
void CodeSet::insert(size_t n, const uint8_t* codes, bool* inserted) {
|
630
|
+
for (size_t i = 0; i < n; i++) {
|
631
|
+
auto res = s.insert(
|
632
|
+
std::vector<uint8_t>(codes + i * d, codes + i * d + d));
|
633
|
+
inserted[i] = res.second;
|
634
|
+
}
|
635
|
+
}
|
636
|
+
|
531
637
|
} // namespace faiss
|