faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -123,7 +123,7 @@ void parallel_merge(
123
123
  }
124
124
  }
125
125
 
126
- }; // namespace
126
+ } // namespace
127
127
 
128
128
  void fvec_argsort(size_t n, const float* vals, size_t* perm) {
129
129
  for (size_t i = 0; i < n; i++) {
@@ -544,7 +544,6 @@ void bucket_sort_inplace_parallel(
544
544
 
545
545
  // in this loop, we write elements collected in the previous round
546
546
  // and collect the elements that are overwritten for the next round
547
- size_t tot_written = 0;
548
547
  int round = 0;
549
548
  for (;;) {
550
549
  #pragma omp barrier
@@ -554,9 +553,6 @@ void bucket_sort_inplace_parallel(
554
553
  n_to_write += to_write_2.lims.back();
555
554
  }
556
555
 
557
- tot_written += n_to_write;
558
- // assert(tot_written <= nval);
559
-
560
556
  #pragma omp master
561
557
  {
562
558
  if (verbose >= 1) {
@@ -689,4 +685,143 @@ void matrix_bucket_sort_inplace(
689
685
  }
690
686
  }
691
687
 
688
+ /** Hashtable implementation for int64 -> int64 with external storage
689
+ * implemented for speed and parallel processing.
690
+ */
691
+
692
namespace {

/// log2 of the number of buckets a table of 2^log2_capacity slots is split
/// into: a single bucket below 2^12 slots, then one more bucket bit per
/// extra capacity bit, capped at 2^10 buckets.
int log2_capacity_to_log2_nbucket(int log2_capacity) {
    if (log2_capacity < 12) {
        return 0;
    }
    if (log2_capacity < 20) {
        return log2_capacity - 12;
    }
    return 10;
}

// https://bigprimes.org/
// constexpr: this is a fixed parameter of the hash; a table filled with one
// value of it cannot be searched with another.
constexpr int64_t bigprime = 8955327411143;

/// Hash of a key. Callers mask the result with (capacity - 1), so a
/// negative result is fine.
///
/// The product is computed in unsigned (wrapping) arithmetic: in int64_t,
/// `x * 1000003` overflows -- undefined behavior -- as soon as
/// |x| > INT64_MAX / 1000003 (about 9.2e12). Whenever the signed product
/// fits, converting the wrapped unsigned product back to int64_t yields the
/// same value, so hashes are unchanged for every key the previous code
/// handled with defined behavior. (The unsigned -> signed conversion is
/// modular on all supported compilers and guaranteed so since C++20.)
inline int64_t hash_function(int64_t x) {
    const uint64_t product = static_cast<uint64_t>(x) * uint64_t(1000003);
    return static_cast<int64_t>(product) % bigprime;
}

} // anonymous namespace
708
+
709
/// Marks every slot of a table of 2^log2_capacity slots as empty: the key
/// entry tab[2 * s] and the value entry tab[2 * s + 1] of each slot s are
/// both set to -1.
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab) {
    const size_t nslot = size_t(1) << log2_capacity;
#pragma omp parallel for
    for (int64_t slot = 0; slot < static_cast<int64_t>(nslot); slot++) {
        int64_t* entry = tab + 2 * slot;
        entry[0] = -1; // key
        entry[1] = -1; // value
    }
}
717
+
718
+ void hashtable_int64_to_int64_add(
719
+ int log2_capacity,
720
+ int64_t* tab,
721
+ size_t n,
722
+ const int64_t* keys,
723
+ const int64_t* vals) {
724
+ size_t capacity = (size_t)1 << log2_capacity;
725
+ std::vector<int64_t> hk(n);
726
+ std::vector<uint64_t> bucket_no(n);
727
+ int64_t mask = capacity - 1;
728
+ int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
729
+ size_t nbucket = (size_t)1 << log2_nbucket;
730
+
731
+ #pragma omp parallel for
732
+ for (int64_t i = 0; i < n; i++) {
733
+ hk[i] = hash_function(keys[i]) & mask;
734
+ bucket_no[i] = hk[i] >> (log2_capacity - log2_nbucket);
735
+ }
736
+
737
+ std::vector<int64_t> lims(nbucket + 1);
738
+ std::vector<int64_t> perm(n);
739
+ bucket_sort(
740
+ n,
741
+ bucket_no.data(),
742
+ nbucket,
743
+ lims.data(),
744
+ perm.data(),
745
+ omp_get_max_threads());
746
+
747
+ int num_errors = 0;
748
+ #pragma omp parallel for reduction(+ : num_errors)
749
+ for (int64_t bucket = 0; bucket < nbucket; bucket++) {
750
+ size_t k0 = bucket << (log2_capacity - log2_nbucket);
751
+ size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
752
+
753
+ for (size_t i = lims[bucket]; i < lims[bucket + 1]; i++) {
754
+ int64_t j = perm[i];
755
+ assert(bucket_no[j] == bucket);
756
+ assert(hk[j] >= k0 && hk[j] < k1);
757
+ size_t slot = hk[j];
758
+ for (;;) {
759
+ if (tab[slot * 2] == -1) { // found!
760
+ tab[slot * 2] = keys[j];
761
+ tab[slot * 2 + 1] = vals[j];
762
+ break;
763
+ } else if (tab[slot * 2] == keys[j]) { // overwrite!
764
+ tab[slot * 2 + 1] = vals[j];
765
+ break;
766
+ }
767
+ slot++;
768
+ if (slot == k1) {
769
+ slot = k0;
770
+ }
771
+ if (slot == hk[j]) { // no free slot left in bucket
772
+ num_errors++;
773
+ break;
774
+ }
775
+ }
776
+ if (num_errors > 0) {
777
+ break;
778
+ }
779
+ }
780
+ }
781
+ FAISS_THROW_IF_NOT_MSG(num_errors == 0, "hashtable capacity exhausted");
782
+ }
783
+
784
+ void hashtable_int64_to_int64_lookup(
785
+ int log2_capacity,
786
+ const int64_t* tab,
787
+ size_t n,
788
+ const int64_t* keys,
789
+ int64_t* vals) {
790
+ size_t capacity = (size_t)1 << log2_capacity;
791
+ std::vector<int64_t> hk(n), bucket_no(n);
792
+ int64_t mask = capacity - 1;
793
+ int log2_nbucket = log2_capacity_to_log2_nbucket(log2_capacity);
794
+
795
+ #pragma omp parallel for
796
+ for (int64_t i = 0; i < n; i++) {
797
+ int64_t k = keys[i];
798
+ int64_t hk = hash_function(k) & mask;
799
+ size_t slot = hk;
800
+
801
+ if (tab[2 * slot] == -1) { // not in table
802
+ vals[i] = -1;
803
+ } else if (tab[2 * slot] == k) { // found!
804
+ vals[i] = tab[2 * slot + 1];
805
+ } else { // need to search in [k0, k1)
806
+ size_t bucket = hk >> (log2_capacity - log2_nbucket);
807
+ size_t k0 = bucket << (log2_capacity - log2_nbucket);
808
+ size_t k1 = (bucket + 1) << (log2_capacity - log2_nbucket);
809
+ for (;;) {
810
+ if (tab[slot * 2] == k) { // found!
811
+ vals[i] = tab[2 * slot + 1];
812
+ break;
813
+ }
814
+ slot++;
815
+ if (slot == k1) {
816
+ slot = k0;
817
+ }
818
+ if (slot == hk) { // bucket is full and not found
819
+ vals[i] = -1;
820
+ break;
821
+ }
822
+ }
823
+ }
824
+ }
825
+ }
826
+
692
827
  } // namespace faiss
@@ -68,4 +68,31 @@ void matrix_bucket_sort_inplace(
68
68
  int64_t* lims,
69
69
  int nt = 0);
70
70
 
71
/** Hashtable implementation for int64 -> int64 with external storage,
 * implemented for fast batch add and lookup.
 *
 * tab is of size 2 * (1 << log2_capacity): slot s stores its key in
 * tab[2 * s] and its value in tab[2 * s + 1].
 * n is the number of elements to add or search.
 *
 * The key -1 marks an empty slot. It must not be added: an entry with key
 * -1 would still look empty.
 *
 * Adding several values for the same key in one batch: an arbitrary one is
 * kept. Adding the same key in different batches: the newer batch
 * overwrites.
 *
 * The slots are split into buckets of consecutive slots, and a key can only
 * be stored in the bucket of its hash. Adding raises an exception when a
 * bucket is full, which can happen before the table as a whole is full; the
 * table may then be partially updated.
 */

/// Initializes the table: every key and value entry is set to -1 (empty).
void hashtable_int64_to_int64_init(int log2_capacity, int64_t* tab);

/// Adds the n pairs (keys[i], vals[i]) to the table.
void hashtable_int64_to_int64_add(
        int log2_capacity,
        int64_t* tab,
        size_t n,
        const int64_t* keys,
        const int64_t* vals);

/// Looks up n keys: vals[i] is set to the value stored for keys[i], or to
/// -1 if keys[i] is not in the table.
void hashtable_int64_to_int64_lookup(
        int log2_capacity,
        const int64_t* tab,
        size_t n,
        const int64_t* keys,
        int64_t* vals);
97
+
71
98
  } // namespace faiss
@@ -0,0 +1,176 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ // This file contains transposing kernels for AVX512 for tiny float/int32
11
+ // matrices, such as 16x2.
12
+
13
+ #ifdef __AVX512F__
14
+
15
+ #include <immintrin.h>
16
+
17
+ namespace faiss {
18
+
19
/// 16x2 -> 2x16 transpose.
/// Input: a row-major 16x2 float matrix in two registers (rows 0-7 in i0,
/// rows 8-15 in i1), element (r, c) at flat index 2 * r + c.
/// Output: o0 holds column 0 (flat indices 0, 2, ..., 30), o1 holds
/// column 1 (flat indices 1, 3, ..., 31).
inline void transpose_16x2(
        const __m512 i0,
        const __m512 i1,
        __m512& o0,
        __m512& o1) {
    // selectors for the even (0, 2) / odd (1, 3) positions of each source
    constexpr int kEven = _MM_SHUFFLE(2, 0, 2, 0);
    constexpr int kOdd = _MM_SHUFFLE(3, 1, 3, 1);

    // Stage 1: regroup 128-bit lanes.
    // even_lanes: 0-3 | 8-11 | 16-19 | 24-27
    // odd_lanes:  4-7 | 12-15 | 20-23 | 28-31
    const __m512 even_lanes = _mm512_shuffle_f32x4(i0, i1, kEven);
    const __m512 odd_lanes = _mm512_shuffle_f32x4(i0, i1, kOdd);

    // Stage 2: inside every lane, keep the even / odd elements.
    // o0: 0 2 4 6 8 ... 30
    // o1: 1 3 5 7 9 ... 31
    o0 = _mm512_shuffle_ps(even_lanes, odd_lanes, kEven);
    o1 = _mm512_shuffle_ps(even_lanes, odd_lanes, kOdd);
}
39
+
40
// 16x4 -> 4x16
// The input is a row-major 16x4 float matrix in four registers, element
// (r, c) at flat index 4 * r + c; oc receives column c.
inline void transpose_16x4(
        const __m512 i0,
        const __m512 i1,
        const __m512 i2,
        const __m512 i3,
        __m512& o0,
        __m512& o1,
        __m512& o2,
        __m512& o3) {
    // assume that we have the following input:
    // i0: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
    // i1: 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
    // i2: 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
    // i3: 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63

    // 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27
    const __m512 r0 = _mm512_shuffle_f32x4(i0, i1, _MM_SHUFFLE(2, 0, 2, 0));
    // 4 5 6 7 12 13 14 15 20 21 22 23 28 29 30 31
    const __m512 r1 = _mm512_shuffle_f32x4(i0, i1, _MM_SHUFFLE(3, 1, 3, 1));
    // 32 33 34 35 40 41 42 43 48 49 50 51 56 57 58 59
    const __m512 r2 = _mm512_shuffle_f32x4(i2, i3, _MM_SHUFFLE(2, 0, 2, 0));
    // 36 37 38 39 44 45 46 47 52 53 54 55 60 61 62 63
    const __m512 r3 = _mm512_shuffle_f32x4(i2, i3, _MM_SHUFFLE(3, 1, 3, 1));

    // 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30
    const __m512 t0 = _mm512_shuffle_ps(r0, r1, _MM_SHUFFLE(2, 0, 2, 0));
    // 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31
    const __m512 t1 = _mm512_shuffle_ps(r0, r1, _MM_SHUFFLE(3, 1, 3, 1));
    // 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62
    const __m512 t2 = _mm512_shuffle_ps(r2, r3, _MM_SHUFFLE(2, 0, 2, 0));
    // 33 35 37 39 41 43 45 47 49 51 53 55 57 59 61 63
    const __m512 t3 = _mm512_shuffle_ps(r2, r3, _MM_SHUFFLE(3, 1, 3, 1));

    // positions in the 32-element concatenation (a, b) of permutex2var:
    // idx0 selects the even ones (0 2 ... 30), idx1 the odd ones (1 3 ... 31)
    const __m512i idx0 = _mm512_set_epi32(
            30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    const __m512i idx1 = _mm512_set_epi32(
            31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);

    // 0 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60
    o0 = _mm512_permutex2var_ps(t0, idx0, t2);
    // 1 5 9 13 17 21 25 29 33 37 41 45 49 53 57 61
    o1 = _mm512_permutex2var_ps(t1, idx0, t3);
    // 2 6 10 14 18 22 26 30 34 38 42 46 50 54 58 62
    o2 = _mm512_permutex2var_ps(t0, idx1, t2);
    // 3 7 11 15 19 23 27 31 35 39 43 47 51 55 59 63
    o3 = _mm512_permutex2var_ps(t1, idx1, t3);
}
88
+
89
/// 16x8 -> 8x16 transpose.
/// Input: a row-major 16x8 float matrix in eight registers (i0 holds rows
/// 0-1, i1 rows 2-3, ..., i7 rows 14-15), element (r, c) at flat index
/// 8 * r + c. Output: oc holds column c, i.e. flat indices c, c + 8, ...,
/// c + 120.
inline void transpose_16x8(
        const __m512 i0,
        const __m512 i1,
        const __m512 i2,
        const __m512 i3,
        const __m512 i4,
        const __m512 i5,
        const __m512 i6,
        const __m512 i7,
        __m512& o0,
        __m512& o1,
        __m512& o2,
        __m512& o3,
        __m512& o4,
        __m512& o5,
        __m512& o6,
        __m512& o7) {
    // Stage 1: interleave pairs of inputs inside every 128-bit lane.
    // lo01: 0 16 1 17 | 4 20 5 21 | 8 24 9 25 | 12 28 13 29
    // hi01: 2 18 3 19 | 6 22 7 23 | 10 26 11 27 | 14 30 15 31
    // (lo23/hi23, lo45/hi45, lo67/hi67: same pattern, offset by 32/64/96)
    const __m512 lo01 = _mm512_unpacklo_ps(i0, i1);
    const __m512 hi01 = _mm512_unpackhi_ps(i0, i1);
    const __m512 lo23 = _mm512_unpacklo_ps(i2, i3);
    const __m512 hi23 = _mm512_unpackhi_ps(i2, i3);
    const __m512 lo45 = _mm512_unpacklo_ps(i4, i5);
    const __m512 hi45 = _mm512_unpackhi_ps(i4, i5);
    const __m512 lo67 = _mm512_unpacklo_ps(i6, i7);
    const __m512 hi67 = _mm512_unpackhi_ps(i6, i7);

    // Stage 2: combine 64-bit halves of the interleaved pairs.
    constexpr int kLow = _MM_SHUFFLE(1, 0, 1, 0);
    constexpr int kHigh = _MM_SHUFFLE(3, 2, 3, 2);
    // q0: 0 16 32 48 | 4 20 36 52 | 8 24 40 56 | 12 28 44 60
    // q1: 1 17 33 49 | 5 21 37 53 | 9 25 41 57 | 13 29 45 61
    // q2: 2 18 34 50 | 6 22 38 54 | 10 26 42 58 | 14 30 46 62
    // q3: 3 19 35 51 | 7 23 39 55 | 11 27 43 59 | 15 31 47 63
    // q4..q7: same as q0..q3 for inputs i4..i7 (values + 64)
    const __m512 q0 = _mm512_shuffle_ps(lo01, lo23, kLow);
    const __m512 q1 = _mm512_shuffle_ps(lo01, lo23, kHigh);
    const __m512 q2 = _mm512_shuffle_ps(hi01, hi23, kLow);
    const __m512 q3 = _mm512_shuffle_ps(hi01, hi23, kHigh);
    const __m512 q4 = _mm512_shuffle_ps(lo45, lo67, kLow);
    const __m512 q5 = _mm512_shuffle_ps(lo45, lo67, kHigh);
    const __m512 q6 = _mm512_shuffle_ps(hi45, hi67, kLow);
    const __m512 q7 = _mm512_shuffle_ps(hi45, hi67, kHigh);

    // Stage 3: cross-lane merge of (qN, qN+4). `first` picks elements
    // 0 8 1 9 2 10 3 11 of each source, `second` elements 4 12 5 13 6 14 7 15.
    const __m512i first = _mm512_set_epi32(
            27, 19, 26, 18, 25, 17, 24, 16, 11, 3, 10, 2, 9, 1, 8, 0);
    const __m512i second = _mm512_set_epi32(
            31, 23, 30, 22, 29, 21, 28, 20, 15, 7, 14, 6, 13, 5, 12, 4);

    // columns 0-3
    o0 = _mm512_permutex2var_ps(q0, first, q4);
    o1 = _mm512_permutex2var_ps(q1, first, q5);
    o2 = _mm512_permutex2var_ps(q2, first, q6);
    o3 = _mm512_permutex2var_ps(q3, first, q7);
    // columns 4-7
    o4 = _mm512_permutex2var_ps(q0, second, q4);
    o5 = _mm512_permutex2var_ps(q1, second, q5);
    o6 = _mm512_permutex2var_ps(q2, second, q6);
    o7 = _mm512_permutex2var_ps(q3, second, q7);
}
173
+
174
+ } // namespace faiss
175
+
176
+ #endif
@@ -7,6 +7,7 @@
7
7
 
8
8
  // -*- c++ -*-
9
9
 
10
+ #include <faiss/Index.h>
10
11
  #include <faiss/utils/utils.h>
11
12
 
12
13
  #include <cassert>
@@ -28,6 +29,8 @@
28
29
  #include <omp.h>
29
30
 
30
31
  #include <algorithm>
32
+ #include <set>
33
+ #include <type_traits>
31
34
  #include <vector>
32
35
 
33
36
  #include <faiss/impl/AuxIndexStructures.h>
@@ -101,6 +104,9 @@ int sgemv_(
101
104
 
102
105
  namespace faiss {
103
106
 
107
// Extra tags reported by get_compile_options(), which appends this string
// verbatim after the CPU SIMD tag. Empty here; it is meant to be filled in
// at load time by GPU Faiss.
std::string gpu_compile_options{};
109
+
104
110
  std::string get_compile_options() {
105
111
  std::string options;
106
112
 
@@ -109,17 +115,27 @@ std::string get_compile_options() {
109
115
  options += "OPTIMIZE ";
110
116
  #endif
111
117
 
112
- #ifdef __AVX2__
113
- options += "AVX2";
118
+ #ifdef __AVX512F__
119
+ options += "AVX512 ";
120
+ #elif defined(__AVX2__)
121
+ options += "AVX2 ";
122
+ #elif defined(__ARM_FEATURE_SVE)
123
+ options += "SVE NEON ";
114
124
  #elif defined(__aarch64__)
115
- options += "NEON";
125
+ options += "NEON ";
116
126
  #else
117
- options += "GENERIC";
127
+ options += "GENERIC ";
118
128
  #endif
119
129
 
130
+ options += gpu_compile_options;
131
+
120
132
  return options;
121
133
  }
122
134
 
135
+ std::string get_version() {
136
+ return VERSION_STRING;
137
+ }
138
+
123
139
  #ifdef _MSC_VER
124
140
  double getmillisecs() {
125
141
  LARGE_INTEGER ts;
@@ -423,15 +439,35 @@ void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist) {
423
439
  }
424
440
  }
425
441
 
426
- size_t ivec_checksum(size_t n, const int32_t* asigned) {
427
- const uint32_t* a = reinterpret_cast<const uint32_t*>(asigned);
428
- size_t cs = 112909;
442
/// Order-dependent checksum of n int32 values.
/// Elements are folded in from the last one to the first. Each contribution
/// a[i] * 1686049 is computed in 32-bit unsigned arithmetic (wrapping modulo
/// 2^32) before being added to the 64-bit accumulator.
uint64_t ivec_checksum(size_t n, const int32_t* assigned) {
    const uint32_t* words = reinterpret_cast<const uint32_t*>(assigned);
    uint64_t checksum = 112909;
    for (size_t i = n; i > 0; i--) {
        const uint32_t term = words[i - 1] * 1686049u;
        checksum = checksum * 65713 + term;
    }
    return checksum;
}
434
450
 
451
/** Order-dependent checksum of n bytes.
 *
 * The first n / 4 * 4 bytes are read as n / 4 native-endian 32-bit words
 * and folded exactly like ivec_checksum() (same seed and multipliers, last
 * word first). The n % 4 trailing bytes are then folded in, in order.
 * Keep the constants in sync with ivec_checksum().
 */
uint64_t bvec_checksum(size_t n, const uint8_t* a) {
    const size_t nword = n / 4;
    uint64_t cs = 112909;
    // Words are loaded with memcpy, not through a casted int32 pointer:
    // `a` need not be 4-byte aligned (bvecs_checksum passes a + i * d for
    // any d). The result equals ivec_checksum(n / 4, a) on aligned input.
    for (size_t w = nword; w > 0; w--) {
        uint32_t word;
        memcpy(&word, a + 4 * (w - 1), sizeof(word));
        cs = cs * 65713 + word * 1686049u; // 32-bit product, as in ivec
    }
    // trailing bytes: indexed with i, so each one is read in bounds
    for (size_t i = nword * 4; i < n; i++) {
        cs = cs * 65713 + a[i] * 1686049;
    }
    return cs;
}
458
+
459
+ void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs) {
460
+ // MSVC can't accept unsigned index for #pragma omp parallel for
461
+ // so below codes only accept n <= std::numeric_limits<ssize_t>::max()
462
+ using ssize_t = std::make_signed<std::size_t>::type;
463
+ const ssize_t size = n;
464
+ #pragma omp parallel for if (size > 1000)
465
+ for (ssize_t i_ = 0; i_ < size; i_++) {
466
+ const auto i = static_cast<std::size_t>(i_);
467
+ cs[i] = bvec_checksum(d, a + i * d);
468
+ }
469
+ }
470
+
435
471
  const float* fvecs_maybe_subsample(
436
472
  size_t d,
437
473
  size_t* n,
@@ -528,4 +564,81 @@ bool check_openmp() {
528
564
  return true;
529
565
  }
530
566
 
567
namespace {

/// Length of the leading run of row[0 .. n) whose elements all satisfy
/// `keep`: n when no element fails, and n unchanged when n <= 0.
template <typename T, typename Pred>
int64_t leading_run_length(int64_t n, const T* row, Pred keep) {
    int64_t i = 0;
    while (i < n && keep(row[i])) {
        ++i;
    }
    return i < n ? i : n;
}

/// Number of leading entries of row strictly below threshold (the scan
/// stops at the first entry for which `row[i] < threshold` is false,
/// e.g. NaN).
template <typename T>
int64_t count_lt(int64_t n, const T* row, T threshold) {
    return leading_run_length(
            n, row, [threshold](const T& v) { return v < threshold; });
}

/// Number of leading entries of row strictly above threshold.
template <typename T>
int64_t count_gt(int64_t n, const T* row, T threshold) {
    return leading_run_length(
            n, row, [threshold](const T& v) { return v > threshold; });
}

} // namespace
590
+
591
+ template <typename T>
592
+ void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
593
+ this->L_res = L_res_2;
594
+ L_res_2[0] = 0;
595
+ int64_t j = 0;
596
+ for (int64_t i = 0; i < nq; i++) {
597
+ int64_t n_in;
598
+ if (!mask || !mask[i]) {
599
+ const T* row = D + i * k;
600
+ n_in = keep_max ? count_gt(k, row, r2) : count_lt(k, row, r2);
601
+ } else {
602
+ n_in = lim_remain[j + 1] - lim_remain[j];
603
+ j++;
604
+ }
605
+ L_res_2[i + 1] = n_in; // L_res_2[i] + n_in;
606
+ }
607
+ // cumsum
608
+ for (int64_t i = 0; i < nq; i++) {
609
+ L_res_2[i + 1] += L_res_2[i];
610
+ }
611
+ }
612
+
613
+ template <typename T>
614
+ void CombinerRangeKNN<T>::write_result(T* D_res, int64_t* I_res) {
615
+ FAISS_THROW_IF_NOT(L_res);
616
+ int64_t j = 0;
617
+ for (int64_t i = 0; i < nq; i++) {
618
+ int64_t n_in = L_res[i + 1] - L_res[i];
619
+ T* D_row = D_res + L_res[i];
620
+ int64_t* I_row = I_res + L_res[i];
621
+ if (!mask || !mask[i]) {
622
+ memcpy(D_row, D + i * k, n_in * sizeof(*D_row));
623
+ memcpy(I_row, I + i * k, n_in * sizeof(*I_row));
624
+ } else {
625
+ memcpy(D_row, D_remain + lim_remain[j], n_in * sizeof(*D_row));
626
+ memcpy(I_row, I_remain + lim_remain[j], n_in * sizeof(*I_row));
627
+ j++;
628
+ }
629
+ }
630
+ }
631
+
632
+ // explicit template instantiations
633
+ template struct CombinerRangeKNN<float>;
634
+ template struct CombinerRangeKNN<int16_t>;
635
+
636
+ void CodeSet::insert(size_t n, const uint8_t* codes, bool* inserted) {
637
+ for (size_t i = 0; i < n; i++) {
638
+ auto res = s.insert(
639
+ std::vector<uint8_t>(codes + i * d, codes + i * d + d));
640
+ inserted[i] = res.second;
641
+ }
642
+ }
643
+
531
644
  } // namespace faiss
@@ -17,7 +17,9 @@
17
17
  #define FAISS_utils_h
18
18
 
19
19
  #include <stdint.h>
20
+ #include <set>
20
21
  #include <string>
22
+ #include <vector>
21
23
 
22
24
  #include <faiss/impl/platform_macros.h>
23
25
  #include <faiss/utils/Heap.h>
@@ -35,6 +37,9 @@ std::string get_compile_options();
35
37
  * Get some stats about the system
36
38
  **************************************************/
37
39
 
40
/// Returns the Faiss version string (VERSION_STRING) the library was
/// compiled with.
std::string get_version();
42
+
38
43
  /// ms elapsed since some arbitrary epoch
39
44
  double getmillisecs();
40
45
 
@@ -47,25 +52,6 @@ uint64_t get_cycles();
47
52
  * Misc matrix and vector manipulation functions
48
53
  ***************************************************************************/
49
54
 
50
- /** compute c := a + bf * b for a, b and c tables
51
- *
52
- * @param n size of the tables
53
- * @param a size n
54
- * @param b size n
55
- * @param c restult table, size n
56
- */
57
- void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
58
-
59
- /** same as fvec_madd, also return index of the min of the result table
60
- * @return index of the min of table c
61
- */
62
- int fvec_madd_and_argmin(
63
- size_t n,
64
- const float* a,
65
- float bf,
66
- const float* b,
67
- float* c);
68
-
69
55
  /* perform a reflection (not an efficient implementation, just for test ) */
70
56
  void reflection(const float* u, float* x, size_t n, size_t d, size_t nu);
71
57
 
@@ -121,7 +107,19 @@ int ivec_hist(size_t n, const int* v, int vmax, int* hist);
121
107
  void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist);
122
108
 
123
109
  /// compute a checksum on a table.
124
- size_t ivec_checksum(size_t n, const int32_t* a);
110
uint64_t ivec_checksum(size_t n, const int32_t* a);

/** compute a checksum on a table of n bytes.
 *
 * The bytes are grouped into n / 4 native-endian 32-bit words that are
 * folded like ivec_checksum; the n % 4 trailing bytes are then folded in.
 */
uint64_t bvec_checksum(size_t n, const uint8_t* a);

/** compute checksums for the rows of a matrix: cs[i] = bvec_checksum(d, a + i * d)
 *
 * @param n number of rows
 * @param d size per row, in bytes
 * @param a matrix to handle, size n * d
 * @param cs output checksums, size n
 */
void bvecs_checksum(size_t n, size_t d, const uint8_t* a, uint64_t* cs);
125
123
 
126
124
  /** random subsamples a set of vectors if there are too many of them
127
125
  *
@@ -163,6 +161,48 @@ uint64_t hash_bytes(const uint8_t* bytes, int64_t n);
163
161
  /** Whether OpenMP annotations were respected. */
164
162
  bool check_openmp();
165
163
 
164
/** This class is used to combine range and knn search results
 * in contrib.exhaustive_search.range_search_gpu.
 *
 * Usage:
 * 1. Set the input pointers.
 * 2. Call compute_sizes() with an array of nq + 1 limits.
 * 3. Allocate D_res / I_res of size L_res[nq].
 * 4. Call write_result().
 */
template <typename T>
struct CombinerRangeKNN {
    int64_t nq;    ///< nb of queries
    size_t k;      ///< number of neighbors for the knn search part
    T r2;          ///< range search radius
    bool keep_max; ///< whether to keep max values instead of min.

    CombinerRangeKNN(int64_t nq, size_t k, T r2, bool keep_max)
            : nq(nq), k(k), r2(r2), keep_max(keep_max) {}

    /// Knn search results
    const int64_t* I = nullptr; ///< size nq * k
    const T* D = nullptr;       ///< size nq * k

    /// optional: range search results (ignored if mask is NULL).
    /// mask[i] is true for the queries whose knn results are NOT used: they
    /// take their results from the range-search part below. Size nq.
    const bool* mask = nullptr;
    /// range search results for the masked queries, nrange = sum(mask)
    const int64_t* lim_remain = nullptr; ///< size nrange + 1
    const T* D_remain = nullptr;         ///< size lim_remain[nrange]
    const int64_t* I_remain = nullptr;   ///< size lim_remain[nrange]

    /// result limits, set by compute_sizes()
    const int64_t* L_res = nullptr; ///< size nq + 1

    /// Phase 1: compute sizes into limits array (of size nq + 1)
    void compute_sizes(int64_t* L_res);

    /// Phase 2: caller allocates D_res and I_res (size L_res[nq])
    /// Phase 3: fill in D_res and I_res
    void write_result(T* D_res, int64_t* I_res);
};

/** Set of byte codes of size d; insert() reports which codes were not yet
 * present. */
struct CodeSet {
    size_t d; ///< size of a code in bytes
    std::set<std::vector<uint8_t>> s;

    explicit CodeSet(size_t d) : d(d) {}

    /// Adds the n codes stored contiguously in `codes` (size n * d);
    /// inserted[i] is set to whether code i was new to the set.
    void insert(size_t n, const uint8_t* codes, bool* inserted);
};
205
+
166
206
  } // namespace faiss
167
207
 
168
208
  #endif /* FAISS_utils_h */