faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -559,15 +559,13 @@ struct simd16uint16 {
559
559
  }
560
560
 
561
561
  // Checks whether the other holds exactly the same bytes.
562
- bool is_same_as(simd16uint16 other) const {
563
- const bool equal0 =
564
- (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
565
- 0xffff);
566
- const bool equal1 =
567
- (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
568
- 0xffff);
569
-
570
- return equal0 && equal1;
562
+ template <typename T>
563
+ bool is_same_as(T other) const {
564
+ const auto o = detail::simdlib::reinterpret_u16(other.data);
565
+ const auto equals = detail::simdlib::binary_func(data, o)
566
+ .template call<&vceqq_u16>();
567
+ const auto equal = vandq_u16(equals.val[0], equals.val[1]);
568
+ return vminvq_u16(equal) == 0xffffu;
571
569
  }
572
570
 
573
571
  simd16uint16 operator~() const {
@@ -689,13 +687,12 @@ inline void cmplt_min_max_fast(
689
687
  simd16uint16& minIndices,
690
688
  simd16uint16& maxValues,
691
689
  simd16uint16& maxIndices) {
692
- const uint16x8x2_t comparison = uint16x8x2_t{
693
- vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
694
- vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
690
+ const uint16x8x2_t comparison =
691
+ detail::simdlib::binary_func(
692
+ candidateValues.data, currentValues.data)
693
+ .call<&vcltq_u16>();
695
694
 
696
- minValues.data = uint16x8x2_t{
697
- vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
698
- vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
695
+ minValues = min(candidateValues, currentValues);
699
696
  minIndices.data = uint16x8x2_t{
700
697
  vbslq_u16(
701
698
  comparison.val[0],
@@ -706,9 +703,7 @@ inline void cmplt_min_max_fast(
706
703
  candidateIndices.data.val[1],
707
704
  currentIndices.data.val[1])};
708
705
 
709
- maxValues.data = uint16x8x2_t{
710
- vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
711
- vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
706
+ maxValues = max(candidateValues, currentValues);
712
707
  maxIndices.data = uint16x8x2_t{
713
708
  vbslq_u16(
714
709
  comparison.val[0],
@@ -869,13 +864,13 @@ struct simd32uint8 {
869
864
  }
870
865
 
871
866
  // Checks whether the other holds exactly the same bytes.
872
- bool is_same_as(simd32uint8 other) const {
873
- const bool equal0 =
874
- (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
875
- const bool equal1 =
876
- (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
877
-
878
- return equal0 && equal1;
867
+ template <typename T>
868
+ bool is_same_as(T other) const {
869
+ const auto o = detail::simdlib::reinterpret_u8(other.data);
870
+ const auto equals = detail::simdlib::binary_func(data, o)
871
+ .template call<&vceqq_u8>();
872
+ const auto equal = vandq_u8(equals.val[0], equals.val[1]);
873
+ return vminvq_u8(equal) == 0xffu;
879
874
  }
880
875
  };
881
876
 
@@ -960,27 +955,28 @@ struct simd8uint32 {
960
955
  return *this;
961
956
  }
962
957
 
963
- bool operator==(simd8uint32 other) const {
964
- const auto equals = detail::simdlib::binary_func(data, other.data)
965
- .call<&vceqq_u32>();
966
- const auto equal = vandq_u32(equals.val[0], equals.val[1]);
967
- return vminvq_u32(equal) == 0xffffffff;
958
+ simd8uint32 operator==(simd8uint32 other) const {
959
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
960
+ .call<&vceqq_u32>()};
968
961
  }
969
962
 
970
- bool operator!=(simd8uint32 other) const {
971
- return !(*this == other);
963
+ simd8uint32 operator~() const {
964
+ return simd8uint32{
965
+ detail::simdlib::unary_func(data).call<&vmvnq_u32>()};
972
966
  }
973
967
 
974
- // Checks whether the other holds exactly the same bytes.
975
- bool is_same_as(simd8uint32 other) const {
976
- const bool equal0 =
977
- (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
978
- 0xffffffff);
979
- const bool equal1 =
980
- (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
981
- 0xffffffff);
968
+ simd8uint32 operator!=(simd8uint32 other) const {
969
+ return ~(*this == other);
970
+ }
982
971
 
983
- return equal0 && equal1;
972
+ // Checks whether the other holds exactly the same bytes.
973
+ template <typename T>
974
+ bool is_same_as(T other) const {
975
+ const auto o = detail::simdlib::reinterpret_u32(other.data);
976
+ const auto equals = detail::simdlib::binary_func(data, o)
977
+ .template call<&vceqq_u32>();
978
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
979
+ return vminvq_u32(equal) == 0xffffffffu;
984
980
  }
985
981
 
986
982
  void clear() {
@@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast(
1053
1049
  simd8uint32& minIndices,
1054
1050
  simd8uint32& maxValues,
1055
1051
  simd8uint32& maxIndices) {
1056
- const uint32x4x2_t comparison = uint32x4x2_t{
1057
- vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1058
- vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1059
-
1060
- minValues.data = uint32x4x2_t{
1061
- vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1062
- vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1052
+ const uint32x4x2_t comparison =
1053
+ detail::simdlib::binary_func(
1054
+ candidateValues.data, currentValues.data)
1055
+ .call<&vcltq_u32>();
1056
+
1057
+ minValues.data = detail::simdlib::binary_func(
1058
+ candidateValues.data, currentValues.data)
1059
+ .call<&vminq_u32>();
1063
1060
  minIndices.data = uint32x4x2_t{
1064
1061
  vbslq_u32(
1065
1062
  comparison.val[0],
@@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast(
1070
1067
  candidateIndices.data.val[1],
1071
1068
  currentIndices.data.val[1])};
1072
1069
 
1073
- maxValues.data = uint32x4x2_t{
1074
- vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1075
- vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1070
+ maxValues.data = detail::simdlib::binary_func(
1071
+ candidateValues.data, currentValues.data)
1072
+ .call<&vmaxq_u32>();
1076
1073
  maxIndices.data = uint32x4x2_t{
1077
1074
  vbslq_u32(
1078
1075
  comparison.val[0],
@@ -1167,28 +1164,25 @@ struct simd8float32 {
1167
1164
  return *this;
1168
1165
  }
1169
1166
 
1170
- bool operator==(simd8float32 other) const {
1171
- const auto equals =
1167
+ simd8uint32 operator==(simd8float32 other) const {
1168
+ return simd8uint32{
1172
1169
  detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
1173
- .call<&vceqq_f32>();
1174
- const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1175
- return vminvq_u32(equal) == 0xffffffff;
1170
+ .call<&vceqq_f32>()};
1176
1171
  }
1177
1172
 
1178
- bool operator!=(simd8float32 other) const {
1179
- return !(*this == other);
1173
+ simd8uint32 operator!=(simd8float32 other) const {
1174
+ return ~(*this == other);
1180
1175
  }
1181
1176
 
1182
1177
  // Checks whether the other holds exactly the same bytes.
1183
- bool is_same_as(simd8float32 other) const {
1184
- const bool equal0 =
1185
- (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
1186
- 0xffffffff);
1187
- const bool equal1 =
1188
- (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
1189
- 0xffffffff);
1190
-
1191
- return equal0 && equal1;
1178
+ template <typename T>
1179
+ bool is_same_as(T other) const {
1180
+ const auto o = detail::simdlib::reinterpret_f32(other.data);
1181
+ const auto equals =
1182
+ detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
1183
+ .template call<&vceqq_f32>();
1184
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1185
+ return vminvq_u32(equal) == 0xffffffffu;
1192
1186
  }
1193
1187
 
1194
1188
  std::string tostring() const {
@@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast(
1302
1296
  simd8uint32& minIndices,
1303
1297
  simd8float32& maxValues,
1304
1298
  simd8uint32& maxIndices) {
1305
- const uint32x4x2_t comparison = uint32x4x2_t{
1306
- vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1307
- vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1308
-
1309
- minValues.data = float32x4x2_t{
1310
- vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1311
- vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1299
+ const uint32x4x2_t comparison =
1300
+ detail::simdlib::binary_func<::uint32x4x2_t>(
1301
+ candidateValues.data, currentValues.data)
1302
+ .call<&vcltq_f32>();
1303
+
1304
+ minValues.data = detail::simdlib::binary_func(
1305
+ candidateValues.data, currentValues.data)
1306
+ .call<&vminq_f32>();
1312
1307
  minIndices.data = uint32x4x2_t{
1313
1308
  vbslq_u32(
1314
1309
  comparison.val[0],
@@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast(
1319
1314
  candidateIndices.data.val[1],
1320
1315
  currentIndices.data.val[1])};
1321
1316
 
1322
- maxValues.data = float32x4x2_t{
1323
- vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1324
- vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1317
+ maxValues.data = detail::simdlib::binary_func(
1318
+ candidateValues.data, currentValues.data)
1319
+ .call<&vmaxq_f32>();
1325
1320
  maxIndices.data = uint32x4x2_t{
1326
1321
  vbslq_u32(
1327
1322
  comparison.val[0],
@@ -123,7 +123,7 @@ void parallel_merge(
123
123
  }
124
124
  }
125
125
 
126
- }; // namespace
126
+ } // namespace
127
127
 
128
128
  void fvec_argsort(size_t n, const float* vals, size_t* perm) {
129
129
  for (size_t i = 0; i < n; i++) {
@@ -544,7 +544,6 @@ void bucket_sort_inplace_parallel(
544
544
 
545
545
  // in this loop, we write elements collected in the previous round
546
546
  // and collect the elements that are overwritten for the next round
547
- size_t tot_written = 0;
548
547
  int round = 0;
549
548
  for (;;) {
550
549
  #pragma omp barrier
@@ -554,9 +553,6 @@ void bucket_sort_inplace_parallel(
554
553
  n_to_write += to_write_2.lims.back();
555
554
  }
556
555
 
557
- tot_written += n_to_write;
558
- // assert(tot_written <= nval);
559
-
560
556
  #pragma omp master
561
557
  {
562
558
  if (verbose >= 1) {
@@ -689,4 +685,143 @@ void matrix_bucket_sort_inplace(
689
685
  }
690
686
  }
691
687
 
688
+ /** Hashtable implementation for int64 -> int64 with external storage
689
+ * implemented for speed and parallel processing.
690
+ */
691
+
692
namespace {

/** Maps the log2 of the table capacity to the log2 of the number of buckets
 * the table is partitioned into.
 *
 * Tables with fewer than 2^12 slots use a single bucket, tables with fewer
 * than 2^20 slots use buckets of 2^12 slots each, and larger tables use
 * 2^10 buckets.
 */
int log2_capacity_to_log2_nbucket(int log2_capacity) {
    return log2_capacity < 12 ? 0
            : log2_capacity < 20 ? log2_capacity - 12
                                 : 10;
}

// https://bigprimes.org/
// constexpr: the modulus is never modified, so it should not be a mutable
// global.
constexpr int64_t bigprime = 8955327411143;

/** Hash of a key: (x * 1000003) % bigprime.
 *
 * The product is computed with unsigned (modular) arithmetic and converted
 * back to int64_t, which yields the value that wrapping signed arithmetic
 * would produce without relying on signed overflow (undefined behavior).
 * The result can be negative for negative products; callers mask it to
 * obtain a slot index.
 */
inline int64_t hash_function(int64_t x) {
    const uint64_t product = static_cast<uint64_t>(x) * uint64_t(1000003);
    return static_cast<int64_t>(product) % bigprime;
}

} // anonymous namespace
708
+
709
/** Marks every slot of the hashtable as empty.
 *
 * @param log2_capacity  log2 of the number of slots in the table
 * @param tab            table storage, 2 * (1 << log2_capacity) entries;
 *                       each slot is a (key, value) pair and both entries
 *                       are set to -1
 */
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
    // kept signed so that it has the same type as the (signed) loop index
    const int64_t capacity = int64_t(1) << log2_capacity;
#pragma omp parallel for
    for (int64_t i = 0; i < capacity; i++) {
        tab[2 * i] = -1;
        tab[2 * i + 1] = -1;
    }
}
717
+
718
+ void hashtable_int64_to_int64_add(
719
+ int log2_capacity,
720
+ int64_t* tab,
721
+ size_t n,
722
+ const int64_t* keys,
723
+ const int64_t* vals) {
724
+ size_t capacity = (size_t)1 << log2_capacity;
725
+ std::vector<int64_t> hk(n);
726
+ std::vector<uint64_t> bucket_no(n);
727
+ int64_t mask = capacity - 1;
728
+ int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
729
+ size_t nbucket = (size_t)1 << log2_nbucket;
730
+
731
+ #pragma omp parallel for
732
+ for (int64_t i = 0; i < n; i++) {
733
+ hk[i] = hash_function(keys[i]) & mask;
734
+ bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
735
+ }
736
+
737
+ std::vector<int64_t> lims(nbucket + 1);
738
+ std::vector<int64_t> perm(n);
739
+ bucket_sort(
740
+ n,
741
+ bucket_no.data(),
742
+ nbucket,
743
+ lims.data(),
744
+ perm.data(),
745
+ omp_get_max_threads());
746
+
747
+ int num_errors = 0;
748
+ #pragma omp parallel for reduction(+ : num_errors)
749
+ for (int64_t bucket = 0; bucket < nbucket; bucket++) {
750
+ size_t k0 = bucket << (log2_capacity - log2_nbucket);
751
+ size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
752
+
753
+ for (size_t i = lims[bucket]; i < lims[bucket + 1]; i++) {
754
+ int64_t j = perm[i];
755
+ assert(bucket_no[j] == bucket);
756
+ assert(hk[j] >= k0 && hk[j] < k1);
757
+ size_t slot = hk[j];
758
+ for (;;) {
759
+ if (tab[slot * 2] == -1) { // found!
760
+ tab[slot * 2] = keys[j];
761
+ tab[slot * 2 + 1] = vals[j];
762
+ break;
763
+ } else if (tab[slot * 2] == keys[j]) { // overwrite!
764
+ tab[slot * 2 + 1] = vals[j];
765
+ break;
766
+ }
767
+ slot++;
768
+ if (slot == k1) {
769
+ slot = k0;
770
+ }
771
+ if (slot == hk[j]) { // no free slot left in bucket
772
+ num_errors++;
773
+ break;
774
+ }
775
+ }
776
+ if (num_errors > 0) {
777
+ break;
778
+ }
779
+ }
780
+ }
781
+ FAISS_THROW_IF_NOT_MSG(num_errors == 0, "hashtable capacity exhausted");
782
+ }
783
+
784
+ void hashtable_int64_to_int64_lookup(
785
+ int log2_capacity,
786
+ const int64_t* tab,
787
+ size_t n,
788
+ const int64_t* keys,
789
+ int64_t* vals) {
790
+ size_t capacity = (size_t)1 << log2_capacity;
791
+ std::vector<int64_t> hk(n), bucket_no(n);
792
+ int64_t mask = capacity - 1;
793
+ int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
794
+
795
+ #pragma omp parallel for
796
+ for (int64_t i = 0; i < n; i++) {
797
+ int64_t k = keys[i];
798
+ int64_t hk = hash_function(k) & mask;
799
+ size_t slot = hk;
800
+
801
+ if (tab[2 * slot] == -1) { // not in table
802
+ vals[i] = -1;
803
+ } else if (tab[2 * slot] == k) { // found!
804
+ vals[i] = tab[2 * slot + 1];
805
+ } else { // need to search in [k0, k1)
806
+ size_t bucket = hk >> (log2_capacity - log2_nbucket);
807
+ size_t k0 = bucket << (log2_capacity - log2_nbucket);
808
+ size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
809
+ for (;;) {
810
+ if (tab[slot * 2] == k) { // found!
811
+ vals[i] = tab[2 * slot + 1];
812
+ break;
813
+ }
814
+ slot++;
815
+ if (slot == k1) {
816
+ slot = k0;
817
+ }
818
+ if (slot == hk) { // bucket is full and not found
819
+ vals[i] = -1;
820
+ break;
821
+ }
822
+ }
823
+ }
824
+ }
825
+ }
826
+
692
827
  } // namespace faiss
@@ -68,4 +68,31 @@ void matrix_bucket_sort_inplace(
68
68
  int64_t* lims,
69
69
  int nt = 0);
70
70
 
71
/** Hashtable implementation for int64 -> int64 with external storage,
 * implemented for fast batch add and lookup.
 *
 * tab is of size 2 * (1 << log2_capacity).
 * n is the number of elements to add or search.
 *
 * If the same key is added several times within one batch, an arbitrary one
 * of its values gets added. If it is added again in a different batch, the
 * newer batch overwrites the older value.
 * Adding raises an exception if the capacity is exhausted.
 */
81
+
82
+ void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab);
83
+
84
+ void hashtable_int64_to_int64_add(
85
+ int log2_capacity,
86
+ int64_t* tab,
87
+ size_t n,
88
+ const int64_t* keys,
89
+ const int64_t* vals);
90
+
91
+ void hashtable_int64_to_int64_lookup(
92
+ int log2_capacity,
93
+ const int64_t* tab,
94
+ size_t n,
95
+ const int64_t* keys,
96
+ int64_t* vals);
97
+
71
98
  } // namespace faiss
@@ -28,6 +28,8 @@
28
28
  #include <omp.h>
29
29
 
30
30
  #include <algorithm>
31
+ #include <set>
32
+ #include <type_traits>
31
33
  #include <vector>
32
34
 
33
35
  #include <faiss/impl/AuxIndexStructures.h>
@@ -101,6 +103,9 @@ int sgemv_(
101
103
 
102
104
  namespace faiss {
103
105
 
106
+ // this will be set at load time from GPU Faiss
107
+ std::string gpu_compile_options;
108
+
104
109
  std::string get_compile_options() {
105
110
  std::string options;
106
111
 
@@ -110,13 +115,17 @@ std::string get_compile_options() {
110
115
  #endif
111
116
 
112
117
  #ifdef __AVX2__
113
- options += "AVX2";
118
+ options += "AVX2 ";
119
+ #elif __AVX512F__
120
+ options += "AVX512 ";
114
121
  #elif defined(__aarch64__)
115
- options += "NEON";
122
+ options += "NEON ";
116
123
  #else
117
- options += "GENERIC";
124
+ options += "GENERIC ";
118
125
  #endif
119
126
 
127
+ options += gpu_compile_options;
128
+
120
129
  return options;
121
130
  }
122
131
 
@@ -423,15 +432,35 @@ void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist) {
423
432
  }
424
433
  }
425
434
 
426
- size_t ivec_checksum(size_t n, const int32_t* asigned) {
427
- const uint32_t* a = reinterpret_cast<const uint32_t*>(asigned);
428
- size_t cs = 112909;
435
+ uint64_t ivec_checksum(size_t n, const int32_t* assigned) {
436
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(assigned);
437
+ uint64_t cs = 112909;
429
438
  while (n--) {
430
439
  cs = cs * 65713 + a[n] * 1686049;
431
440
  }
432
441
  return cs;
433
442
  }
434
443
 
444
/** Checksum of a vector of n bytes.
 *
 * The first n / 4 * 4 bytes are consumed as 32-bit words with the same
 * recurrence as ivec_checksum (seed 112909, words visited from last to
 * first, 32-bit wrapping product); keep the two in sync. The remaining
 * n % 4 bytes are then folded in one by one, in increasing order.
 *
 * The words are read with memcpy because `a` is not necessarily aligned for
 * a 32-bit access (bvecs_checksum passes a + i * d for arbitrary d).
 *
 * @param n  number of bytes
 * @param a  bytes to checksum, size n
 * @return   the checksum
 */
uint64_t bvec_checksum(size_t n, const uint8_t* a) {
    const size_t n_words = n / 4;
    uint64_t cs = 112909;
    for (size_t w = n_words; w-- > 0;) {
        uint32_t word;
        memcpy(&word, a + 4 * w, sizeof(word));
        // the product deliberately wraps in 32 bits, as in ivec_checksum
        cs = cs * 65713 + word * 1686049;
    }
    // tail bytes: index with i (indexing with n would read past the end of
    // the buffer and ignore the tail)
    for (size_t i = n_words * 4; i < n; i++) {
        cs = cs * 65713 + a[i] * 1686049;
    }
    return cs;
}
451
+
452
+ void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs) {
453
+ // MSVC can't accept unsigned index for #pragma omp parallel for
454
+ // so below codes only accept n <= std::numeric_limits<ssize_t>::max()
455
+ using ssize_t = std::make_signed<std::size_t>::type;
456
+ const ssize_t size = n;
457
+ #pragma omp parallel for if (size > 1000)
458
+ for (ssize_t i_ = 0; i_ < size; i_++) {
459
+ const auto i = static_cast<std::size_t>(i_);
460
+ cs[i] = bvec_checksum(d, a + i * d);
461
+ }
462
+ }
463
+
435
464
  const float* fvecs_maybe_subsample(
436
465
  size_t d,
437
466
  size_t* n,
@@ -528,4 +557,81 @@ bool check_openmp() {
528
557
  return true;
529
558
  }
530
559
 
560
namespace {

/// Length of the leading run of entries of row[0..n) that are strictly
/// below threshold, i.e. the index of the first entry for which
/// (row[i] < threshold) is false, or n if there is none.
template <typename T>
int64_t count_lt(int64_t n, const T* row, T threshold) {
    int64_t pos = 0;
    while (pos < n) {
        if (row[pos] < threshold) {
            ++pos;
            continue;
        }
        return pos;
    }
    return n;
}

/// Length of the leading run of entries of row[0..n) that are strictly
/// above threshold, i.e. the index of the first entry for which
/// (row[i] > threshold) is false, or n if there is none.
template <typename T>
int64_t count_gt(int64_t n, const T* row, T threshold) {
    int64_t pos = 0;
    while (pos < n) {
        if (row[pos] > threshold) {
            ++pos;
            continue;
        }
        return pos;
    }
    return n;
}

} // namespace
583
+
584
+ template <typename T>
585
+ void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
586
+ this->L_res = L_res_2;
587
+ L_res_2[0] = 0;
588
+ int64_t j = 0;
589
+ for (int64_t i = 0; i < nq; i++) {
590
+ int64_t n_in;
591
+ if (!mask || !mask[i]) {
592
+ const T* row = D + i * k;
593
+ n_in = keep_max ? count_gt(k, row, r2) : count_lt(k, row, r2);
594
+ } else {
595
+ n_in = lim_remain[j + 1] - lim_remain[j];
596
+ j++;
597
+ }
598
+ L_res_2[i + 1] = n_in; // L_res_2[i] + n_in;
599
+ }
600
+ // cumsum
601
+ for (int64_t i = 0; i < nq; i++) {
602
+ L_res_2[i + 1] += L_res_2[i];
603
+ }
604
+ }
605
+
606
+ template <typename T>
607
+ void CombinerRangeKNN<T>::write_result(T* D_res, int64_t* I_res) {
608
+ FAISS_THROW_IF_NOT(L_res);
609
+ int64_t j = 0;
610
+ for (int64_t i = 0; i < nq; i++) {
611
+ int64_t n_in = L_res[i + 1] - L_res[i];
612
+ T* D_row = D_res + L_res[i];
613
+ int64_t* I_row = I_res + L_res[i];
614
+ if (!mask || !mask[i]) {
615
+ memcpy(D_row, D + i * k, n_in * sizeof(*D_row));
616
+ memcpy(I_row, I + i * k, n_in * sizeof(*I_row));
617
+ } else {
618
+ memcpy(D_row, D_remain + lim_remain[j], n_in * sizeof(*D_row));
619
+ memcpy(I_row, I_remain + lim_remain[j], n_in * sizeof(*I_row));
620
+ j++;
621
+ }
622
+ }
623
+ }
624
+
625
+ // explicit template instantiations
626
+ template struct CombinerRangeKNN<float>;
627
+ template struct CombinerRangeKNN<int16_t>;
628
+
629
+ void CodeSet::insert(size_t n, const uint8_t* codes, bool* inserted) {
630
+ for (size_t i = 0; i < n; i++) {
631
+ auto res = s.insert(
632
+ std::vector<uint8_t>(codes + i * d, codes + i * d + d));
633
+ inserted[i] = res.second;
634
+ }
635
+ }
636
+
531
637
  } // namespace faiss