faiss 0.5.3 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +4 -4
- data/ext/faiss/index.cpp +63 -45
- data/ext/faiss/index_binary.cpp +37 -27
- data/ext/faiss/kmeans.cpp +9 -8
- data/ext/faiss/pca_matrix.cpp +9 -7
- data/ext/faiss/product_quantizer.cpp +13 -11
- data/ext/faiss/utils.cpp +4 -2
- data/ext/faiss/utils.h +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +214 -82
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +97 -249
- data/vendor/faiss/faiss/Clustering.h +18 -0
- data/vendor/faiss/faiss/IVFlib.cpp +67 -44
- data/vendor/faiss/faiss/Index.cpp +25 -12
- data/vendor/faiss/faiss/Index.h +26 -4
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
- data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
- data/vendor/faiss/faiss/IndexFastScan.h +35 -24
- data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
- data/vendor/faiss/faiss/IndexFlat.h +32 -14
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
- data/vendor/faiss/faiss/IndexHNSW.h +30 -14
- data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
- data/vendor/faiss/faiss/IndexIVF.h +47 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
- data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
- data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
- data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
- data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
- data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +13 -13
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +29 -6
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
- data/vendor/faiss/faiss/VectorTransform.h +39 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +55 -51
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
- data/vendor/faiss/faiss/impl/HNSW.h +21 -40
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
- data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
- data/vendor/faiss/faiss/impl/NSG.h +20 -10
- data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
- data/vendor/faiss/faiss/impl/Panorama.h +265 -78
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
- data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +115 -28
- data/vendor/faiss/faiss/index_io.h +53 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +507 -559
- data/vendor/faiss/faiss/utils/distances.h +118 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
- data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
- data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +21 -14
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +156 -42
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -11,15 +11,14 @@
|
|
|
11
11
|
#include <faiss/VectorTransform.h>
|
|
12
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
13
13
|
|
|
14
|
-
#include <chrono>
|
|
15
14
|
#include <cinttypes>
|
|
16
15
|
#include <cmath>
|
|
17
16
|
#include <cstdio>
|
|
18
17
|
#include <cstring>
|
|
19
|
-
|
|
20
|
-
#include <omp.h>
|
|
18
|
+
#include <limits>
|
|
21
19
|
|
|
22
20
|
#include <faiss/IndexFlat.h>
|
|
21
|
+
#include <faiss/impl/ClusteringHelpers.h>
|
|
23
22
|
#include <faiss/impl/FaissAssert.h>
|
|
24
23
|
#include <faiss/impl/kmeans1d.h>
|
|
25
24
|
#include <faiss/utils/distances.h>
|
|
@@ -28,10 +27,10 @@
|
|
|
28
27
|
|
|
29
28
|
namespace faiss {
|
|
30
29
|
|
|
31
|
-
Clustering::Clustering(int
|
|
30
|
+
Clustering::Clustering(int d_, int k_) : d(d_), k(k_) {}
|
|
32
31
|
|
|
33
|
-
Clustering::Clustering(int
|
|
34
|
-
: ClusteringParameters(cp), d(
|
|
32
|
+
Clustering::Clustering(int d_, int k_, const ClusteringParameters& cp)
|
|
33
|
+
: ClusteringParameters(cp), d(d_), k(k_) {}
|
|
35
34
|
|
|
36
35
|
void Clustering::post_process_centroids() {
|
|
37
36
|
if (spherical) {
|
|
@@ -58,213 +57,6 @@ void Clustering::train(
|
|
|
58
57
|
weights);
|
|
59
58
|
}
|
|
60
59
|
|
|
61
|
-
namespace {
|
|
62
|
-
|
|
63
|
-
uint64_t get_actual_rng_seed(const int seed) {
|
|
64
|
-
return (seed >= 0)
|
|
65
|
-
? seed
|
|
66
|
-
: static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
|
|
67
|
-
.time_since_epoch()
|
|
68
|
-
.count());
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
idx_t subsample_training_set(
|
|
72
|
-
const Clustering& clus,
|
|
73
|
-
idx_t nx,
|
|
74
|
-
const uint8_t* x,
|
|
75
|
-
size_t line_size,
|
|
76
|
-
const float* weights,
|
|
77
|
-
uint8_t** x_out,
|
|
78
|
-
float** weights_out) {
|
|
79
|
-
if (clus.verbose) {
|
|
80
|
-
printf("Sampling a subset of %zd / %" PRId64 " for training\n",
|
|
81
|
-
clus.k * clus.max_points_per_centroid,
|
|
82
|
-
nx);
|
|
83
|
-
}
|
|
84
|
-
|
|
85
|
-
const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
|
|
86
|
-
|
|
87
|
-
std::vector<int> perm;
|
|
88
|
-
if (clus.use_faster_subsampling) {
|
|
89
|
-
// use subsampling with splitmix64 rng
|
|
90
|
-
SplitMix64RandomGenerator rng(actual_seed);
|
|
91
|
-
|
|
92
|
-
const idx_t new_nx = clus.k * clus.max_points_per_centroid;
|
|
93
|
-
perm.resize(new_nx);
|
|
94
|
-
for (idx_t i = 0; i < new_nx; i++) {
|
|
95
|
-
perm[i] = rng.rand_int(nx);
|
|
96
|
-
}
|
|
97
|
-
} else {
|
|
98
|
-
// use subsampling with a default std rng
|
|
99
|
-
perm.resize(nx);
|
|
100
|
-
rand_perm(perm.data(), nx, actual_seed);
|
|
101
|
-
}
|
|
102
|
-
|
|
103
|
-
nx = clus.k * clus.max_points_per_centroid;
|
|
104
|
-
uint8_t* x_new = new uint8_t[nx * line_size];
|
|
105
|
-
*x_out = x_new;
|
|
106
|
-
|
|
107
|
-
// might be worth omp-ing as well
|
|
108
|
-
for (idx_t i = 0; i < nx; i++) {
|
|
109
|
-
memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
|
|
110
|
-
}
|
|
111
|
-
if (weights) {
|
|
112
|
-
float* weights_new = new float[nx];
|
|
113
|
-
for (idx_t i = 0; i < nx; i++) {
|
|
114
|
-
weights_new[i] = weights[perm[i]];
|
|
115
|
-
}
|
|
116
|
-
*weights_out = weights_new;
|
|
117
|
-
} else {
|
|
118
|
-
*weights_out = nullptr;
|
|
119
|
-
}
|
|
120
|
-
return nx;
|
|
121
|
-
}
|
|
122
|
-
|
|
123
|
-
/** compute centroids as (weighted) sum of training points
|
|
124
|
-
*
|
|
125
|
-
* @param x training vectors, size n * code_size (from codec)
|
|
126
|
-
* @param codec how to decode the vectors (if NULL then cast to float*)
|
|
127
|
-
* @param weights per-training vector weight, size n (or NULL)
|
|
128
|
-
* @param assign nearest centroid for each training vector, size n
|
|
129
|
-
* @param k_frozen do not update the k_frozen first centroids
|
|
130
|
-
* @param centroids centroid vectors (output only), size k * d
|
|
131
|
-
* @param hassign histogram of assignments per centroid (size k),
|
|
132
|
-
* should be 0 on input
|
|
133
|
-
*
|
|
134
|
-
*/
|
|
135
|
-
|
|
136
|
-
void compute_centroids(
|
|
137
|
-
size_t d,
|
|
138
|
-
size_t k,
|
|
139
|
-
size_t n,
|
|
140
|
-
size_t k_frozen,
|
|
141
|
-
const uint8_t* x,
|
|
142
|
-
const Index* codec,
|
|
143
|
-
const int64_t* assign,
|
|
144
|
-
const float* weights,
|
|
145
|
-
float* hassign,
|
|
146
|
-
float* centroids) {
|
|
147
|
-
k -= k_frozen;
|
|
148
|
-
centroids += k_frozen * d;
|
|
149
|
-
|
|
150
|
-
memset(centroids, 0, sizeof(*centroids) * d * k);
|
|
151
|
-
|
|
152
|
-
size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
|
|
153
|
-
|
|
154
|
-
#pragma omp parallel
|
|
155
|
-
{
|
|
156
|
-
int nt = omp_get_num_threads();
|
|
157
|
-
int rank = omp_get_thread_num();
|
|
158
|
-
|
|
159
|
-
// this thread is taking care of centroids c0:c1
|
|
160
|
-
size_t c0 = (k * rank) / nt;
|
|
161
|
-
size_t c1 = (k * (rank + 1)) / nt;
|
|
162
|
-
std::vector<float> decode_buffer(d);
|
|
163
|
-
|
|
164
|
-
for (size_t i = 0; i < n; i++) {
|
|
165
|
-
int64_t ci = assign[i];
|
|
166
|
-
assert(ci >= 0 && ci < k + k_frozen);
|
|
167
|
-
ci -= k_frozen;
|
|
168
|
-
if (ci >= c0 && ci < c1) {
|
|
169
|
-
float* c = centroids + ci * d;
|
|
170
|
-
const float* xi;
|
|
171
|
-
if (!codec) {
|
|
172
|
-
xi = reinterpret_cast<const float*>(x + i * line_size);
|
|
173
|
-
} else {
|
|
174
|
-
float* xif = decode_buffer.data();
|
|
175
|
-
codec->sa_decode(1, x + i * line_size, xif);
|
|
176
|
-
xi = xif;
|
|
177
|
-
}
|
|
178
|
-
if (weights) {
|
|
179
|
-
float w = weights[i];
|
|
180
|
-
hassign[ci] += w;
|
|
181
|
-
for (size_t j = 0; j < d; j++) {
|
|
182
|
-
c[j] += xi[j] * w;
|
|
183
|
-
}
|
|
184
|
-
} else {
|
|
185
|
-
hassign[ci] += 1.0;
|
|
186
|
-
for (size_t j = 0; j < d; j++) {
|
|
187
|
-
c[j] += xi[j];
|
|
188
|
-
}
|
|
189
|
-
}
|
|
190
|
-
}
|
|
191
|
-
}
|
|
192
|
-
}
|
|
193
|
-
|
|
194
|
-
#pragma omp parallel for
|
|
195
|
-
for (idx_t ci = 0; ci < k; ci++) {
|
|
196
|
-
if (hassign[ci] == 0) {
|
|
197
|
-
continue;
|
|
198
|
-
}
|
|
199
|
-
float norm = 1 / hassign[ci];
|
|
200
|
-
float* c = centroids + ci * d;
|
|
201
|
-
for (size_t j = 0; j < d; j++) {
|
|
202
|
-
c[j] *= norm;
|
|
203
|
-
}
|
|
204
|
-
}
|
|
205
|
-
}
|
|
206
|
-
|
|
207
|
-
// a bit above machine epsilon for float16
|
|
208
|
-
#define EPS (1 / 1024.)
|
|
209
|
-
|
|
210
|
-
/** Handle empty clusters by splitting larger ones.
|
|
211
|
-
*
|
|
212
|
-
* It works by slightly changing the centroids to make 2 clusters from
|
|
213
|
-
* a single one. Takes the same arguments as compute_centroids.
|
|
214
|
-
*
|
|
215
|
-
* @return nb of splitting operations (larger is worse)
|
|
216
|
-
*/
|
|
217
|
-
int split_clusters(
|
|
218
|
-
size_t d,
|
|
219
|
-
size_t k,
|
|
220
|
-
size_t n,
|
|
221
|
-
size_t k_frozen,
|
|
222
|
-
float* hassign,
|
|
223
|
-
float* centroids) {
|
|
224
|
-
k -= k_frozen;
|
|
225
|
-
centroids += k_frozen * d;
|
|
226
|
-
|
|
227
|
-
/* Take care of void clusters */
|
|
228
|
-
size_t nsplit = 0;
|
|
229
|
-
RandomGenerator rng(1234);
|
|
230
|
-
for (size_t ci = 0; ci < k; ci++) {
|
|
231
|
-
if (hassign[ci] == 0) { /* need to redefine a centroid */
|
|
232
|
-
size_t cj;
|
|
233
|
-
for (cj = 0; true; cj = (cj + 1) % k) {
|
|
234
|
-
/* probability to pick this cluster for split */
|
|
235
|
-
float p = (hassign[cj] - 1.0) / (float)(n - k);
|
|
236
|
-
float r = rng.rand_float();
|
|
237
|
-
if (r < p) {
|
|
238
|
-
break; /* found our cluster to be split */
|
|
239
|
-
}
|
|
240
|
-
}
|
|
241
|
-
memcpy(centroids + ci * d,
|
|
242
|
-
centroids + cj * d,
|
|
243
|
-
sizeof(*centroids) * d);
|
|
244
|
-
|
|
245
|
-
/* small symmetric perturbation */
|
|
246
|
-
for (size_t j = 0; j < d; j++) {
|
|
247
|
-
if (j % 2 == 0) {
|
|
248
|
-
centroids[ci * d + j] *= 1 + EPS;
|
|
249
|
-
centroids[cj * d + j] *= 1 - EPS;
|
|
250
|
-
} else {
|
|
251
|
-
centroids[ci * d + j] *= 1 - EPS;
|
|
252
|
-
centroids[cj * d + j] *= 1 + EPS;
|
|
253
|
-
}
|
|
254
|
-
}
|
|
255
|
-
|
|
256
|
-
/* assume even split of the cluster */
|
|
257
|
-
hassign[ci] = hassign[cj] / 2;
|
|
258
|
-
hassign[cj] -= hassign[ci];
|
|
259
|
-
nsplit++;
|
|
260
|
-
}
|
|
261
|
-
}
|
|
262
|
-
|
|
263
|
-
return nsplit;
|
|
264
|
-
}
|
|
265
|
-
|
|
266
|
-
} // namespace
|
|
267
|
-
|
|
268
60
|
void Clustering::train_encoded(
|
|
269
61
|
idx_t nx,
|
|
270
62
|
const uint8_t* x_in,
|
|
@@ -272,7 +64,7 @@ void Clustering::train_encoded(
|
|
|
272
64
|
Index& index,
|
|
273
65
|
const float* weights) {
|
|
274
66
|
FAISS_THROW_IF_NOT_FMT(
|
|
275
|
-
nx >= k,
|
|
67
|
+
nx >= static_cast<idx_t>(k),
|
|
276
68
|
"Number of training points (%" PRId64
|
|
277
69
|
") should be at least "
|
|
278
70
|
"as large as number of clusters (%zd)",
|
|
@@ -280,13 +72,13 @@ void Clustering::train_encoded(
|
|
|
280
72
|
k);
|
|
281
73
|
|
|
282
74
|
FAISS_THROW_IF_NOT_FMT(
|
|
283
|
-
(!codec || codec->d == d),
|
|
75
|
+
(!codec || static_cast<size_t>(codec->d) == d),
|
|
284
76
|
"Codec dimension %d not the same as data dimension %d",
|
|
285
77
|
int(codec->d),
|
|
286
78
|
int(d));
|
|
287
79
|
|
|
288
80
|
FAISS_THROW_IF_NOT_FMT(
|
|
289
|
-
index.d == d,
|
|
81
|
+
static_cast<size_t>(index.d) == d,
|
|
290
82
|
"Index dimension %d not the same as data dimension %d",
|
|
291
83
|
int(index.d),
|
|
292
84
|
int(d));
|
|
@@ -309,16 +101,16 @@ void Clustering::train_encoded(
|
|
|
309
101
|
std::unique_ptr<float[]> del3;
|
|
310
102
|
size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
|
|
311
103
|
|
|
312
|
-
if (nx > k * max_points_per_centroid) {
|
|
104
|
+
if (static_cast<size_t>(nx) > k * max_points_per_centroid) {
|
|
313
105
|
uint8_t* x_new;
|
|
314
106
|
float* weights_new;
|
|
315
|
-
nx = subsample_training_set(
|
|
107
|
+
nx = detail::subsample_training_set(
|
|
316
108
|
*this, nx, x, line_size, weights, &x_new, &weights_new);
|
|
317
109
|
del1.reset(x_new);
|
|
318
110
|
x = x_new;
|
|
319
111
|
del3.reset(weights_new);
|
|
320
112
|
weights = weights_new;
|
|
321
|
-
} else if (nx < k * min_points_per_centroid) {
|
|
113
|
+
} else if (static_cast<size_t>(nx) < k * min_points_per_centroid) {
|
|
322
114
|
fprintf(stderr,
|
|
323
115
|
"WARNING clustering %" PRId64
|
|
324
116
|
" points to %zd centroids: "
|
|
@@ -328,7 +120,7 @@ void Clustering::train_encoded(
|
|
|
328
120
|
idx_t(k) * min_points_per_centroid);
|
|
329
121
|
}
|
|
330
122
|
|
|
331
|
-
if (nx == k) {
|
|
123
|
+
if (static_cast<size_t>(nx) == k) {
|
|
332
124
|
// this is a corner case, just copy training set to clusters
|
|
333
125
|
if (verbose) {
|
|
334
126
|
printf("Number of training points (%" PRId64
|
|
@@ -397,7 +189,7 @@ void Clustering::train_encoded(
|
|
|
397
189
|
t0 = getmillisecs();
|
|
398
190
|
|
|
399
191
|
// initialize seed
|
|
400
|
-
const uint64_t actual_seed = get_actual_rng_seed(seed);
|
|
192
|
+
const uint64_t actual_seed = detail::get_actual_rng_seed(seed);
|
|
401
193
|
|
|
402
194
|
// temporary buffer to decode vectors during the optimization
|
|
403
195
|
std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
|
|
@@ -407,19 +199,52 @@ void Clustering::train_encoded(
|
|
|
407
199
|
printf("Outer iteration %d / %d\n", redo, nredo);
|
|
408
200
|
}
|
|
409
201
|
|
|
410
|
-
// initialize
|
|
202
|
+
// initialize centroids using the selected method
|
|
411
203
|
centroids.resize(d * k);
|
|
412
|
-
std::vector<int> perm(nx);
|
|
413
204
|
|
|
414
|
-
|
|
205
|
+
size_t k_to_init = k - n_input_centroids;
|
|
206
|
+
if (k_to_init > 0) {
|
|
207
|
+
// Fast path for RANDOM initialization - preserves exact original
|
|
208
|
+
// behavior
|
|
209
|
+
if (init_method == ClusteringInitMethod::RANDOM) {
|
|
210
|
+
std::vector<int> perm(nx);
|
|
211
|
+
rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
|
|
212
|
+
for (size_t i = 0; i < k_to_init; i++) {
|
|
213
|
+
if (!codec) {
|
|
214
|
+
memcpy(centroids.data() + (n_input_centroids + i) * d,
|
|
215
|
+
x + perm[n_input_centroids + i] * line_size,
|
|
216
|
+
line_size);
|
|
217
|
+
} else {
|
|
218
|
+
codec->sa_decode(
|
|
219
|
+
1,
|
|
220
|
+
x + perm[n_input_centroids + i] * line_size,
|
|
221
|
+
centroids.data() + (n_input_centroids + i) * d);
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
} else {
|
|
225
|
+
// For k-means++ and AFK-MC², we need all vectors decoded
|
|
226
|
+
const float* x_float = nullptr;
|
|
227
|
+
std::vector<float> x_decoded;
|
|
415
228
|
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
229
|
+
if (!codec) {
|
|
230
|
+
x_float = reinterpret_cast<const float*>(x);
|
|
231
|
+
} else {
|
|
232
|
+
// Decode all vectors for initialization
|
|
233
|
+
x_decoded.resize(nx * d);
|
|
234
|
+
codec->sa_decode(nx, x, x_decoded.data());
|
|
235
|
+
x_float = x_decoded.data();
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
ClusteringInitialization initializer(d, k_to_init);
|
|
239
|
+
initializer.method = init_method;
|
|
240
|
+
initializer.seed = actual_seed + 1 + redo * 15486557L;
|
|
241
|
+
initializer.afkmc2_chain_length = afkmc2_chain_length;
|
|
242
|
+
initializer.init_centroids(
|
|
243
|
+
nx,
|
|
244
|
+
x_float,
|
|
245
|
+
centroids.data() + n_input_centroids * d,
|
|
246
|
+
n_input_centroids,
|
|
247
|
+
n_input_centroids > 0 ? centroids.data() : nullptr);
|
|
423
248
|
}
|
|
424
249
|
}
|
|
425
250
|
|
|
@@ -453,9 +278,10 @@ void Clustering::train_encoded(
|
|
|
453
278
|
} else {
|
|
454
279
|
// search by blocks of decode_block_size vectors
|
|
455
280
|
size_t code_size = codec->sa_code_size();
|
|
456
|
-
for (size_t i0 = 0; i0 < nx;
|
|
281
|
+
for (size_t i0 = 0; i0 < static_cast<size_t>(nx);
|
|
282
|
+
i0 += decode_block_size) {
|
|
457
283
|
size_t i1 = i0 + decode_block_size;
|
|
458
|
-
if (i1 > nx) {
|
|
284
|
+
if (i1 > static_cast<size_t>(nx)) {
|
|
459
285
|
i1 = nx;
|
|
460
286
|
}
|
|
461
287
|
codec->sa_decode(
|
|
@@ -474,7 +300,7 @@ void Clustering::train_encoded(
|
|
|
474
300
|
|
|
475
301
|
// accumulate objective
|
|
476
302
|
obj = 0;
|
|
477
|
-
for (
|
|
303
|
+
for (idx_t j = 0; j < nx; j++) {
|
|
478
304
|
obj += dis[j];
|
|
479
305
|
}
|
|
480
306
|
|
|
@@ -482,7 +308,7 @@ void Clustering::train_encoded(
|
|
|
482
308
|
std::vector<float> hassign(k);
|
|
483
309
|
|
|
484
310
|
size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
|
|
485
|
-
compute_centroids(
|
|
311
|
+
detail::compute_centroids(
|
|
486
312
|
d,
|
|
487
313
|
k,
|
|
488
314
|
nx,
|
|
@@ -494,7 +320,7 @@ void Clustering::train_encoded(
|
|
|
494
320
|
hassign.data(),
|
|
495
321
|
centroids.data());
|
|
496
322
|
|
|
497
|
-
int nsplit = split_clusters(
|
|
323
|
+
int nsplit = detail::split_clusters(
|
|
498
324
|
d, k, nx, k_frozen, hassign.data(), centroids.data());
|
|
499
325
|
|
|
500
326
|
// collect statistics
|
|
@@ -502,7 +328,7 @@ void Clustering::train_encoded(
|
|
|
502
328
|
obj,
|
|
503
329
|
(getmillisecs() - t0) / 1000.0,
|
|
504
330
|
t_search_tot / 1000,
|
|
505
|
-
imbalance_factor(nx, k, assign.get()),
|
|
331
|
+
imbalance_factor(nx, static_cast<int>(k), assign.get()),
|
|
506
332
|
nsplit};
|
|
507
333
|
iteration_stats.push_back(stats);
|
|
508
334
|
|
|
@@ -529,6 +355,27 @@ void Clustering::train_encoded(
|
|
|
529
355
|
|
|
530
356
|
index.add(k, centroids.data());
|
|
531
357
|
InterruptCallback::check();
|
|
358
|
+
|
|
359
|
+
// Early stopping: if objective didn't change, we've converged.
|
|
360
|
+
// Safe to access iteration_stats[size - 2] because we push_back
|
|
361
|
+
// above, so size >= i + 1, and when i > 0 we have size >= 2.
|
|
362
|
+
if (i > 0) {
|
|
363
|
+
float prev_obj =
|
|
364
|
+
iteration_stats[iteration_stats.size() - 2].obj;
|
|
365
|
+
|
|
366
|
+
double change = (prev_obj == 0)
|
|
367
|
+
? std::numeric_limits<double>::max()
|
|
368
|
+
: std::abs(prev_obj - obj) / std::abs(prev_obj);
|
|
369
|
+
|
|
370
|
+
if (change >= 0 && change <= early_stop_threshold) {
|
|
371
|
+
if (verbose) {
|
|
372
|
+
printf("\n Converged at iteration %d: "
|
|
373
|
+
"objective did not change\n",
|
|
374
|
+
i);
|
|
375
|
+
}
|
|
376
|
+
break;
|
|
377
|
+
}
|
|
378
|
+
}
|
|
532
379
|
}
|
|
533
380
|
|
|
534
381
|
if (verbose) {
|
|
@@ -555,19 +402,19 @@ void Clustering::train_encoded(
|
|
|
555
402
|
}
|
|
556
403
|
}
|
|
557
404
|
|
|
558
|
-
Clustering1D::Clustering1D(int
|
|
405
|
+
Clustering1D::Clustering1D(int k_) : Clustering(1, k_) {}
|
|
559
406
|
|
|
560
|
-
Clustering1D::Clustering1D(int
|
|
561
|
-
: Clustering(1,
|
|
407
|
+
Clustering1D::Clustering1D(int k_, const ClusteringParameters& cp)
|
|
408
|
+
: Clustering(1, k_, cp) {}
|
|
562
409
|
|
|
563
410
|
void Clustering1D::train_exact(idx_t n, const float* x) {
|
|
564
411
|
const float* xt = x;
|
|
565
412
|
|
|
566
413
|
std::unique_ptr<uint8_t[]> del;
|
|
567
|
-
if (n > k * max_points_per_centroid) {
|
|
414
|
+
if (static_cast<size_t>(n) > k * max_points_per_centroid) {
|
|
568
415
|
uint8_t* x_new;
|
|
569
416
|
float* weights_new;
|
|
570
|
-
n = subsample_training_set(
|
|
417
|
+
n = detail::subsample_training_set(
|
|
571
418
|
*this,
|
|
572
419
|
n,
|
|
573
420
|
(uint8_t*)x,
|
|
@@ -592,7 +439,7 @@ float kmeans_clustering(
|
|
|
592
439
|
size_t k,
|
|
593
440
|
const float* x,
|
|
594
441
|
float* centroids) {
|
|
595
|
-
Clustering clus(d, k);
|
|
442
|
+
Clustering clus(static_cast<int>(d), static_cast<int>(k));
|
|
596
443
|
clus.verbose = d * n * k > (size_t(1) << 30);
|
|
597
444
|
// display logs if > 1Gflop per iteration
|
|
598
445
|
IndexFlatL2 index(d);
|
|
@@ -615,13 +462,14 @@ Index* ProgressiveDimIndexFactory::operator()(int dim) {
|
|
|
615
462
|
return new IndexFlatL2(dim);
|
|
616
463
|
}
|
|
617
464
|
|
|
618
|
-
ProgressiveDimClustering::ProgressiveDimClustering(int
|
|
465
|
+
ProgressiveDimClustering::ProgressiveDimClustering(int d_, int k_)
|
|
466
|
+
: d(d_), k(k_) {}
|
|
619
467
|
|
|
620
468
|
ProgressiveDimClustering::ProgressiveDimClustering(
|
|
621
|
-
int
|
|
622
|
-
int
|
|
469
|
+
int d_,
|
|
470
|
+
int k_,
|
|
623
471
|
const ProgressiveDimClusteringParameters& cp)
|
|
624
|
-
: ProgressiveDimClusteringParameters(cp), d(
|
|
472
|
+
: ProgressiveDimClusteringParameters(cp), d(d_), k(k_) {}
|
|
625
473
|
|
|
626
474
|
namespace {
|
|
627
475
|
|
|
@@ -642,7 +490,7 @@ void ProgressiveDimClustering::train(
|
|
|
642
490
|
ProgressiveDimIndexFactory& factory) {
|
|
643
491
|
int d_prev = 0;
|
|
644
492
|
|
|
645
|
-
PCAMatrix pca(d, d);
|
|
493
|
+
PCAMatrix pca(static_cast<int>(d), static_cast<int>(d));
|
|
646
494
|
|
|
647
495
|
std::vector<float> xbuf;
|
|
648
496
|
if (apply_pca) {
|
|
@@ -667,7 +515,7 @@ void ProgressiveDimClustering::train(
|
|
|
667
515
|
}
|
|
668
516
|
std::unique_ptr<Index> clustering_index(factory(di));
|
|
669
517
|
|
|
670
|
-
Clustering clus(di, k, *this);
|
|
518
|
+
Clustering clus(di, static_cast<int>(k), *this);
|
|
671
519
|
if (d_prev > 0) {
|
|
672
520
|
// copy warm-start centroids (padded with 0s)
|
|
673
521
|
clus.centroids.resize(k * di);
|
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
#ifndef FAISS_CLUSTERING_H
|
|
11
11
|
#define FAISS_CLUSTERING_H
|
|
12
12
|
#include <faiss/Index.h>
|
|
13
|
+
#include <faiss/impl/ClusteringInitialization.h>
|
|
13
14
|
|
|
14
15
|
#include <vector>
|
|
15
16
|
|
|
@@ -57,6 +58,23 @@ struct ClusteringParameters {
|
|
|
57
58
|
/// Whether to use splitmix64-based random number generator for subsampling,
|
|
58
59
|
/// which is faster, but may pick duplicate points.
|
|
59
60
|
bool use_faster_subsampling = false;
|
|
61
|
+
|
|
62
|
+
/// Initialization method for centroids.
|
|
63
|
+
/// RANDOM: uniform random sampling (default, current behavior)
|
|
64
|
+
/// KMEANS_PLUS_PLUS: k-means++ (O(nkd), better quality)
|
|
65
|
+
/// AFK_MC2: Assumption-Free K-MC² (O(nd) + O(mk²d), fast approximation)
|
|
66
|
+
ClusteringInitMethod init_method = ClusteringInitMethod::RANDOM;
|
|
67
|
+
|
|
68
|
+
/// Chain length for AFK-MC² initialization.
|
|
69
|
+
/// Only used when init_method = AFK_MC2.
|
|
70
|
+
/// Longer chains give better approximation but are slower.
|
|
71
|
+
uint16_t afkmc2_chain_length = 50;
|
|
72
|
+
|
|
73
|
+
/// Early stop threshold, the range is [0, 1].
|
|
74
|
+
/// The value of 0 implies a default Faiss behavior,
|
|
75
|
+
/// so the training process stops only if an error
|
|
76
|
+
/// is unchanged from the previous iteration.
|
|
77
|
+
double early_stop_threshold = 0.0;
|
|
60
78
|
};
|
|
61
79
|
|
|
62
80
|
struct ClusteringIterationStats {
|