faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -559,15 +559,13 @@ struct simd16uint16 {
559
559
  }
560
560
 
561
561
  // Checks whether the other holds exactly the same bytes.
562
- bool is_same_as(simd16uint16 other) const {
563
- const bool equal0 =
564
- (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
565
- 0xffff);
566
- const bool equal1 =
567
- (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
568
- 0xffff);
569
-
570
- return equal0 && equal1;
562
+ template <typename T>
563
+ bool is_same_as(T other) const {
564
+ const auto o = detail::simdlib::reinterpret_u16(other.data);
565
+ const auto equals = detail::simdlib::binary_func(data, o)
566
+ .template call<&vceqq_u16>();
567
+ const auto equal = vandq_u16(equals.val[0], equals.val[1]);
568
+ return vminvq_u16(equal) == 0xffffu;
571
569
  }
572
570
 
573
571
  simd16uint16 operator~() const {
@@ -689,13 +687,12 @@ inline void cmplt_min_max_fast(
689
687
  simd16uint16& minIndices,
690
688
  simd16uint16& maxValues,
691
689
  simd16uint16& maxIndices) {
692
- const uint16x8x2_t comparison = uint16x8x2_t{
693
- vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
694
- vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
690
+ const uint16x8x2_t comparison =
691
+ detail::simdlib::binary_func(
692
+ candidateValues.data, currentValues.data)
693
+ .call<&vcltq_u16>();
695
694
 
696
- minValues.data = uint16x8x2_t{
697
- vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
698
- vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
695
+ minValues = min(candidateValues, currentValues);
699
696
  minIndices.data = uint16x8x2_t{
700
697
  vbslq_u16(
701
698
  comparison.val[0],
@@ -706,9 +703,7 @@ inline void cmplt_min_max_fast(
706
703
  candidateIndices.data.val[1],
707
704
  currentIndices.data.val[1])};
708
705
 
709
- maxValues.data = uint16x8x2_t{
710
- vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
711
- vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
706
+ maxValues = max(candidateValues, currentValues);
712
707
  maxIndices.data = uint16x8x2_t{
713
708
  vbslq_u16(
714
709
  comparison.val[0],
@@ -869,13 +864,13 @@ struct simd32uint8 {
869
864
  }
870
865
 
871
866
  // Checks whether the other holds exactly the same bytes.
872
- bool is_same_as(simd32uint8 other) const {
873
- const bool equal0 =
874
- (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
875
- const bool equal1 =
876
- (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
877
-
878
- return equal0 && equal1;
867
+ template <typename T>
868
+ bool is_same_as(T other) const {
869
+ const auto o = detail::simdlib::reinterpret_u8(other.data);
870
+ const auto equals = detail::simdlib::binary_func(data, o)
871
+ .template call<&vceqq_u8>();
872
+ const auto equal = vandq_u8(equals.val[0], equals.val[1]);
873
+ return vminvq_u8(equal) == 0xffu;
879
874
  }
880
875
  };
881
876
 
@@ -960,27 +955,28 @@ struct simd8uint32 {
960
955
  return *this;
961
956
  }
962
957
 
963
- bool operator==(simd8uint32 other) const {
964
- const auto equals = detail::simdlib::binary_func(data, other.data)
965
- .call<&vceqq_u32>();
966
- const auto equal = vandq_u32(equals.val[0], equals.val[1]);
967
- return vminvq_u32(equal) == 0xffffffff;
958
+ simd8uint32 operator==(simd8uint32 other) const {
959
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
960
+ .call<&vceqq_u32>()};
968
961
  }
969
962
 
970
- bool operator!=(simd8uint32 other) const {
971
- return !(*this == other);
963
+ simd8uint32 operator~() const {
964
+ return simd8uint32{
965
+ detail::simdlib::unary_func(data).call<&vmvnq_u32>()};
972
966
  }
973
967
 
974
- // Checks whether the other holds exactly the same bytes.
975
- bool is_same_as(simd8uint32 other) const {
976
- const bool equal0 =
977
- (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
978
- 0xffffffff);
979
- const bool equal1 =
980
- (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
981
- 0xffffffff);
968
+ simd8uint32 operator!=(simd8uint32 other) const {
969
+ return ~(*this == other);
970
+ }
982
971
 
983
- return equal0 && equal1;
972
+ // Checks whether the other holds exactly the same bytes.
973
+ template <typename T>
974
+ bool is_same_as(T other) const {
975
+ const auto o = detail::simdlib::reinterpret_u32(other.data);
976
+ const auto equals = detail::simdlib::binary_func(data, o)
977
+ .template call<&vceqq_u32>();
978
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
979
+ return vminvq_u32(equal) == 0xffffffffu;
984
980
  }
985
981
 
986
982
  void clear() {
@@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast(
1053
1049
  simd8uint32& minIndices,
1054
1050
  simd8uint32& maxValues,
1055
1051
  simd8uint32& maxIndices) {
1056
- const uint32x4x2_t comparison = uint32x4x2_t{
1057
- vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1058
- vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1059
-
1060
- minValues.data = uint32x4x2_t{
1061
- vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1062
- vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1052
+ const uint32x4x2_t comparison =
1053
+ detail::simdlib::binary_func(
1054
+ candidateValues.data, currentValues.data)
1055
+ .call<&vcltq_u32>();
1056
+
1057
+ minValues.data = detail::simdlib::binary_func(
1058
+ candidateValues.data, currentValues.data)
1059
+ .call<&vminq_u32>();
1063
1060
  minIndices.data = uint32x4x2_t{
1064
1061
  vbslq_u32(
1065
1062
  comparison.val[0],
@@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast(
1070
1067
  candidateIndices.data.val[1],
1071
1068
  currentIndices.data.val[1])};
1072
1069
 
1073
- maxValues.data = uint32x4x2_t{
1074
- vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1075
- vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1070
+ maxValues.data = detail::simdlib::binary_func(
1071
+ candidateValues.data, currentValues.data)
1072
+ .call<&vmaxq_u32>();
1076
1073
  maxIndices.data = uint32x4x2_t{
1077
1074
  vbslq_u32(
1078
1075
  comparison.val[0],
@@ -1167,28 +1164,25 @@ struct simd8float32 {
1167
1164
  return *this;
1168
1165
  }
1169
1166
 
1170
- bool operator==(simd8float32 other) const {
1171
- const auto equals =
1167
+ simd8uint32 operator==(simd8float32 other) const {
1168
+ return simd8uint32{
1172
1169
  detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
1173
- .call<&vceqq_f32>();
1174
- const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1175
- return vminvq_u32(equal) == 0xffffffff;
1170
+ .call<&vceqq_f32>()};
1176
1171
  }
1177
1172
 
1178
- bool operator!=(simd8float32 other) const {
1179
- return !(*this == other);
1173
+ simd8uint32 operator!=(simd8float32 other) const {
1174
+ return ~(*this == other);
1180
1175
  }
1181
1176
 
1182
1177
  // Checks whether the other holds exactly the same bytes.
1183
- bool is_same_as(simd8float32 other) const {
1184
- const bool equal0 =
1185
- (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
1186
- 0xffffffff);
1187
- const bool equal1 =
1188
- (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
1189
- 0xffffffff);
1190
-
1191
- return equal0 && equal1;
1178
+ template <typename T>
1179
+ bool is_same_as(T other) const {
1180
+ const auto o = detail::simdlib::reinterpret_f32(other.data);
1181
+ const auto equals =
1182
+ detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
1183
+ .template call<&vceqq_f32>();
1184
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1185
+ return vminvq_u32(equal) == 0xffffffffu;
1192
1186
  }
1193
1187
 
1194
1188
  std::string tostring() const {
@@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast(
1302
1296
  simd8uint32& minIndices,
1303
1297
  simd8float32& maxValues,
1304
1298
  simd8uint32& maxIndices) {
1305
- const uint32x4x2_t comparison = uint32x4x2_t{
1306
- vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1307
- vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1308
-
1309
- minValues.data = float32x4x2_t{
1310
- vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1311
- vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1299
+ const uint32x4x2_t comparison =
1300
+ detail::simdlib::binary_func<::uint32x4x2_t>(
1301
+ candidateValues.data, currentValues.data)
1302
+ .call<&vcltq_f32>();
1303
+
1304
+ minValues.data = detail::simdlib::binary_func(
1305
+ candidateValues.data, currentValues.data)
1306
+ .call<&vminq_f32>();
1312
1307
  minIndices.data = uint32x4x2_t{
1313
1308
  vbslq_u32(
1314
1309
  comparison.val[0],
@@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast(
1319
1314
  candidateIndices.data.val[1],
1320
1315
  currentIndices.data.val[1])};
1321
1316
 
1322
- maxValues.data = float32x4x2_t{
1323
- vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1324
- vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1317
+ maxValues.data = detail::simdlib::binary_func(
1318
+ candidateValues.data, currentValues.data)
1319
+ .call<&vmaxq_f32>();
1325
1320
  maxIndices.data = uint32x4x2_t{
1326
1321
  vbslq_u32(
1327
1322
  comparison.val[0],
@@ -123,7 +123,7 @@ void parallel_merge(
123
123
  }
124
124
  }
125
125
 
126
- }; // namespace
126
+ } // namespace
127
127
 
128
128
  void fvec_argsort(size_t n, const float* vals, size_t* perm) {
129
129
  for (size_t i = 0; i < n; i++) {
@@ -544,7 +544,6 @@ void bucket_sort_inplace_parallel(
544
544
 
545
545
  // in this loop, we write elements collected in the previous round
546
546
  // and collect the elements that are overwritten for the next round
547
- size_t tot_written = 0;
548
547
  int round = 0;
549
548
  for (;;) {
550
549
  #pragma omp barrier
@@ -554,9 +553,6 @@ void bucket_sort_inplace_parallel(
554
553
  n_to_write += to_write_2.lims.back();
555
554
  }
556
555
 
557
- tot_written += n_to_write;
558
- // assert(tot_written <= nval);
559
-
560
556
  #pragma omp master
561
557
  {
562
558
  if (verbose >= 1) {
@@ -689,4 +685,143 @@ void matrix_bucket_sort_inplace(
689
685
  }
690
686
  }
691
687
 
688
+ /** Hashtable implementation for int64 -> int64 with external storage
689
+ * implemented for speed and parallel processing.
690
+ */
691
+
692
+ namespace {
693
+
694
+ int log2_capacity_to_log2_nbucket(int log2_capacity) {
695
+ return log2_capacity < 12 ? 0
696
+ : log2_capacity < 20 ? log2_capacity - 12
697
+ : 10;
698
+ }
699
+
700
+ // https://bigprimes.org/
701
+ int64_t bigprime = 8955327411143;
702
+
703
+ inline int64_t hash_function(int64_t x) {
704
+ return (x * 1000003) % bigprime;
705
+ }
706
+
707
+ } // anonymous namespace
708
+
709
+ void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
710
+ size_t capacity = (size_t)1 << log2_capacity;
711
+ #pragma omp parallel for
712
+ for (int64_t i = 0; i < capacity; i++) {
713
+ tab[2 * i] = -1;
714
+ tab[2 * i + 1] = -1;
715
+ }
716
+ }
717
+
718
+ void hashtable_int64_to_int64_add(
719
+ int log2_capacity,
720
+ int64_t* tab,
721
+ size_t n,
722
+ const int64_t* keys,
723
+ const int64_t* vals) {
724
+ size_t capacity = (size_t)1 << log2_capacity;
725
+ std::vector<int64_t> hk(n);
726
+ std::vector<uint64_t> bucket_no(n);
727
+ int64_t mask = capacity - 1;
728
+ int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
729
+ size_t nbucket = (size_t)1 << log2_nbucket;
730
+
731
+ #pragma omp parallel for
732
+ for (int64_t i = 0; i < n; i++) {
733
+ hk[i] = hash_function(keys[i]) & mask;
734
+ bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
735
+ }
736
+
737
+ std::vector<int64_t> lims(nbucket + 1);
738
+ std::vector<int64_t> perm(n);
739
+ bucket_sort(
740
+ n,
741
+ bucket_no.data(),
742
+ nbucket,
743
+ lims.data(),
744
+ perm.data(),
745
+ omp_get_max_threads());
746
+
747
+ int num_errors = 0;
748
+ #pragma omp parallel for reduction(+ : num_errors)
749
+ for (int64_t bucket = 0; bucket < nbucket; bucket++) {
750
+ size_t k0 = bucket << (log2_capacity - log2_nbucket);
751
+ size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
752
+
753
+ for (size_t i = lims[bucket]; i < lims[bucket + 1]; i++) {
754
+ int64_t j = perm[i];
755
+ assert(bucket_no[j] == bucket);
756
+ assert(hk[j] >= k0 && hk[j] < k1);
757
+ size_t slot = hk[j];
758
+ for (;;) {
759
+ if (tab[slot * 2] == -1) { // found!
760
+ tab[slot * 2] = keys[j];
761
+ tab[slot * 2 + 1] = vals[j];
762
+ break;
763
+ } else if (tab[slot * 2] == keys[j]) { // overwrite!
764
+ tab[slot * 2 + 1] = vals[j];
765
+ break;
766
+ }
767
+ slot++;
768
+ if (slot == k1) {
769
+ slot = k0;
770
+ }
771
+ if (slot == hk[j]) { // no free slot left in bucket
772
+ num_errors++;
773
+ break;
774
+ }
775
+ }
776
+ if (num_errors > 0) {
777
+ break;
778
+ }
779
+ }
780
+ }
781
+ FAISS_THROW_IF_NOT_MSG(num_errors == 0, "hashtable capacity exhausted");
782
+ }
783
+
784
+ void hashtable_int64_to_int64_lookup(
785
+ int log2_capacity,
786
+ const int64_t* tab,
787
+ size_t n,
788
+ const int64_t* keys,
789
+ int64_t* vals) {
790
+ size_t capacity = (size_t)1 << log2_capacity;
791
+ std::vector<int64_t> hk(n), bucket_no(n);
792
+ int64_t mask = capacity - 1;
793
+ int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
794
+
795
+ #pragma omp parallel for
796
+ for (int64_t i = 0; i < n; i++) {
797
+ int64_t k = keys[i];
798
+ int64_t hk = hash_function(k) & mask;
799
+ size_t slot = hk;
800
+
801
+ if (tab[2 * slot] == -1) { // not in table
802
+ vals[i] = -1;
803
+ } else if (tab[2 * slot] == k) { // found!
804
+ vals[i] = tab[2 * slot + 1];
805
+ } else { // need to search in [k0, k1)
806
+ size_t bucket = hk >> (log2_capacity - log2_nbucket);
807
+ size_t k0 = bucket << (log2_capacity - log2_nbucket);
808
+ size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
809
+ for (;;) {
810
+ if (tab[slot * 2] == k) { // found!
811
+ vals[i] = tab[2 * slot + 1];
812
+ break;
813
+ }
814
+ slot++;
815
+ if (slot == k1) {
816
+ slot = k0;
817
+ }
818
+ if (slot == hk) { // bucket is full and not found
819
+ vals[i] = -1;
820
+ break;
821
+ }
822
+ }
823
+ }
824
+ }
825
+ }
826
+
692
827
  } // namespace faiss
@@ -68,4 +68,31 @@ void matrix_bucket_sort_inplace(
68
68
  int64_t* lims,
69
69
  int nt = 0);
70
70
 
71
+ /** Hashtable implementation for int64 -> int64 with external storage
72
+ * implemented for fast batch add and lookup.
73
+ *
74
+ * tab is of size 2 * (1 << log2_capacity)
75
+ * n is the number of elements to add or search
76
+ *
77
+ * adding several values in a same batch: an arbitrary one gets added
78
+ * in different batches: the newer batch overwrites.
79
+ * raises an exception if capacity is exhausted.
80
+ */
81
+
82
+ void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab);
83
+
84
+ void hashtable_int64_to_int64_add(
85
+ int log2_capacity,
86
+ int64_t* tab,
87
+ size_t n,
88
+ const int64_t* keys,
89
+ const int64_t* vals);
90
+
91
+ void hashtable_int64_to_int64_lookup(
92
+ int log2_capacity,
93
+ const int64_t* tab,
94
+ size_t n,
95
+ const int64_t* keys,
96
+ int64_t* vals);
97
+
71
98
  } // namespace faiss
@@ -28,6 +28,8 @@
28
28
  #include <omp.h>
29
29
 
30
30
  #include <algorithm>
31
+ #include <set>
32
+ #include <type_traits>
31
33
  #include <vector>
32
34
 
33
35
  #include <faiss/impl/AuxIndexStructures.h>
@@ -101,6 +103,9 @@ int sgemv_(
101
103
 
102
104
  namespace faiss {
103
105
 
106
+ // this will be set at load time from GPU Faiss
107
+ std::string gpu_compile_options;
108
+
104
109
  std::string get_compile_options() {
105
110
  std::string options;
106
111
 
@@ -110,13 +115,17 @@ std::string get_compile_options() {
110
115
  #endif
111
116
 
112
117
  #ifdef __AVX2__
113
- options += "AVX2";
118
+ options += "AVX2 ";
119
+ #elif __AVX512F__
120
+ options += "AVX512 ";
114
121
  #elif defined(__aarch64__)
115
- options += "NEON";
122
+ options += "NEON ";
116
123
  #else
117
- options += "GENERIC";
124
+ options += "GENERIC ";
118
125
  #endif
119
126
 
127
+ options += gpu_compile_options;
128
+
120
129
  return options;
121
130
  }
122
131
 
@@ -423,15 +432,35 @@ void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist) {
423
432
  }
424
433
  }
425
434
 
426
- size_t ivec_checksum(size_t n, const int32_t* asigned) {
427
- const uint32_t* a = reinterpret_cast<const uint32_t*>(asigned);
428
- size_t cs = 112909;
435
+ uint64_t ivec_checksum(size_t n, const int32_t* assigned) {
436
+ const uint32_t* a = reinterpret_cast<const uint32_t*>(assigned);
437
+ uint64_t cs = 112909;
429
438
  while (n--) {
430
439
  cs = cs * 65713 + a[n] * 1686049;
431
440
  }
432
441
  return cs;
433
442
  }
434
443
 
444
+ uint64_t bvec_checksum(size_t n, const uint8_t* a) {
445
+ uint64_t cs = ivec_checksum(n / 4, (const int32_t*)a);
446
+ for (size_t i = n / 4 * 4; i < n; i++) {
447
+ cs = cs * 65713 + a[n] * 1686049;
448
+ }
449
+ return cs;
450
+ }
451
+
452
+ void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs) {
453
+ // MSVC can't accept unsigned index for #pragma omp parallel for
454
+ // so below codes only accept n <= std::numeric_limits<ssize_t>::max()
455
+ using ssize_t = std::make_signed<std::size_t>::type;
456
+ const ssize_t size = n;
457
+ #pragma omp parallel for if (size > 1000)
458
+ for (ssize_t i_ = 0; i_ < size; i_++) {
459
+ const auto i = static_cast<std::size_t>(i_);
460
+ cs[i] = bvec_checksum(d, a + i * d);
461
+ }
462
+ }
463
+
435
464
  const float* fvecs_maybe_subsample(
436
465
  size_t d,
437
466
  size_t* n,
@@ -528,4 +557,81 @@ bool check_openmp() {
528
557
  return true;
529
558
  }
530
559
 
560
+ namespace {
561
+
562
+ template <typename T>
563
+ int64_t count_lt(int64_t n, const T* row, T threshold) {
564
+ for (int64_t i = 0; i < n; i++) {
565
+ if (!(row[i] < threshold)) {
566
+ return i;
567
+ }
568
+ }
569
+ return n;
570
+ }
571
+
572
+ template <typename T>
573
+ int64_t count_gt(int64_t n, const T* row, T threshold) {
574
+ for (int64_t i = 0; i < n; i++) {
575
+ if (!(row[i] > threshold)) {
576
+ return i;
577
+ }
578
+ }
579
+ return n;
580
+ }
581
+
582
+ } // namespace
583
+
584
+ template <typename T>
585
+ void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
586
+ this->L_res = L_res_2;
587
+ L_res_2[0] = 0;
588
+ int64_t j = 0;
589
+ for (int64_t i = 0; i < nq; i++) {
590
+ int64_t n_in;
591
+ if (!mask || !mask[i]) {
592
+ const T* row = D + i * k;
593
+ n_in = keep_max ? count_gt(k, row, r2) : count_lt(k, row, r2);
594
+ } else {
595
+ n_in = lim_remain[j + 1] - lim_remain[j];
596
+ j++;
597
+ }
598
+ L_res_2[i + 1] = n_in; // L_res_2[i] + n_in;
599
+ }
600
+ // cumsum
601
+ for (int64_t i = 0; i < nq; i++) {
602
+ L_res_2[i + 1] += L_res_2[i];
603
+ }
604
+ }
605
+
606
+ template <typename T>
607
+ void CombinerRangeKNN<T>::write_result(T* D_res, int64_t* I_res) {
608
+ FAISS_THROW_IF_NOT(L_res);
609
+ int64_t j = 0;
610
+ for (int64_t i = 0; i < nq; i++) {
611
+ int64_t n_in = L_res[i + 1] - L_res[i];
612
+ T* D_row = D_res + L_res[i];
613
+ int64_t* I_row = I_res + L_res[i];
614
+ if (!mask || !mask[i]) {
615
+ memcpy(D_row, D + i * k, n_in * sizeof(*D_row));
616
+ memcpy(I_row, I + i * k, n_in * sizeof(*I_row));
617
+ } else {
618
+ memcpy(D_row, D_remain + lim_remain[j], n_in * sizeof(*D_row));
619
+ memcpy(I_row, I_remain + lim_remain[j], n_in * sizeof(*I_row));
620
+ j++;
621
+ }
622
+ }
623
+ }
624
+
625
+ // explicit template instantiations
626
+ template struct CombinerRangeKNN<float>;
627
+ template struct CombinerRangeKNN<int16_t>;
628
+
629
+ void CodeSet::insert(size_t n, const uint8_t* codes, bool* inserted) {
630
+ for (size_t i = 0; i < n; i++) {
631
+ auto res = s.insert(
632
+ std::vector<uint8_t>(codes + i * d, codes + i * d + d));
633
+ inserted[i] = res.second;
634
+ }
635
+ }
636
+
531
637
  } // namespace faiss