faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -0,0 +1,656 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <algorithm>
|
|
9
|
+
#include <cassert>
|
|
10
|
+
#include <cinttypes>
|
|
11
|
+
#include <cmath>
|
|
12
|
+
#include <cstdint>
|
|
13
|
+
#include <cstdio>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <limits>
|
|
16
|
+
#include <memory>
|
|
17
|
+
#include <vector>
|
|
18
|
+
|
|
19
|
+
#include <faiss/SuperKMeans.h>
|
|
20
|
+
#include <faiss/VectorTransform.h>
|
|
21
|
+
#include <faiss/impl/AdSampling.h>
|
|
22
|
+
#include <faiss/impl/ClusteringHelpers.h>
|
|
23
|
+
#include <faiss/impl/FaissAssert.h>
|
|
24
|
+
#include <faiss/impl/PdxLayout.h>
|
|
25
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
26
|
+
#include <faiss/utils/distances.h>
|
|
27
|
+
#include <faiss/utils/random.h>
|
|
28
|
+
#include <faiss/utils/simd_impl/super_kmeans_kernels.h>
|
|
29
|
+
#include <faiss/utils/utils.h>
|
|
30
|
+
|
|
31
|
+
#ifndef FINTEGER
|
|
32
|
+
#define FINTEGER long
|
|
33
|
+
#endif
|
|
34
|
+
|
|
35
|
+
extern "C" {
|
|
36
|
+
|
|
37
|
+
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
38
|
+
|
|
39
|
+
int sgemm_(
|
|
40
|
+
const char* transa,
|
|
41
|
+
const char* transb,
|
|
42
|
+
FINTEGER* m,
|
|
43
|
+
FINTEGER* n,
|
|
44
|
+
FINTEGER* k,
|
|
45
|
+
const float* alpha,
|
|
46
|
+
const float* a,
|
|
47
|
+
FINTEGER* lda,
|
|
48
|
+
const float* b,
|
|
49
|
+
FINTEGER* ldb,
|
|
50
|
+
float* beta,
|
|
51
|
+
float* c,
|
|
52
|
+
FINTEGER* ldc);
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
namespace faiss {
|
|
56
|
+
|
|
57
|
+
namespace {
|
|
58
|
+
|
|
59
|
+
struct TrainState {
|
|
60
|
+
/// Orthogonal rotation. Train in rotated space (X_tilde = X * R);
|
|
61
|
+
/// un-rotate centroids before return.
|
|
62
|
+
faiss::RandomRotationMatrix R;
|
|
63
|
+
|
|
64
|
+
std::vector<float> X_tilde; // (n, d) row-major
|
|
65
|
+
int n = 0;
|
|
66
|
+
std::vector<float> Y_tilde; // (k, d) row-major
|
|
67
|
+
|
|
68
|
+
std::vector<int> assignments; // size n
|
|
69
|
+
std::vector<float> best_dists; // size n; tau per vector
|
|
70
|
+
|
|
71
|
+
/// ||X_tilde[i, 0:d_prime]||^2; recomputed when d_prime changes.
|
|
72
|
+
std::vector<float> x_norms_partial;
|
|
73
|
+
|
|
74
|
+
int d_prime = 0;
|
|
75
|
+
|
|
76
|
+
/// ADSampling threshold table; size d+1.
|
|
77
|
+
std::vector<float> ad_coeff;
|
|
78
|
+
|
|
79
|
+
/// PDX block layout for the trailing pruning sweep: block b covers
|
|
80
|
+
/// original dims [true_block_end[b] - block_dim[b], true_block_end[b]).
|
|
81
|
+
/// Recomputed when d_prime changes.
|
|
82
|
+
std::vector<int> block_dim;
|
|
83
|
+
std::vector<int> true_block_end;
|
|
84
|
+
|
|
85
|
+
/// Counter for the verbose-mode "low pruning" warning.
|
|
86
|
+
int low_pruning_streak = 0;
|
|
87
|
+
bool low_pruning_warning_printed = false;
|
|
88
|
+
|
|
89
|
+
explicit TrainState(int d) : R(d, d) {}
|
|
90
|
+
};
|
|
91
|
+
|
|
92
|
+
/// Rebuild state.block_dim and state.true_block_end from the current
|
|
93
|
+
/// state.d_prime and pdx_block_size. Call after any change to d_prime.
|
|
94
|
+
void rebuild_pdx_block_layout(int d, int pdx_block_size, TrainState& state) {
|
|
95
|
+
const int dp = state.d_prime;
|
|
96
|
+
const int d_trail = d - dp;
|
|
97
|
+
const int n_full_blocks = d_trail / pdx_block_size;
|
|
98
|
+
const int tail = d_trail % pdx_block_size;
|
|
99
|
+
const int n_blocks = n_full_blocks + (tail > 0 ? 1 : 0);
|
|
100
|
+
state.block_dim.assign(n_blocks, pdx_block_size);
|
|
101
|
+
state.true_block_end.resize(n_blocks);
|
|
102
|
+
if (n_blocks > 0) {
|
|
103
|
+
assert(!state.block_dim.empty());
|
|
104
|
+
assert(!state.true_block_end.empty());
|
|
105
|
+
for (int b = 0; b < n_full_blocks; ++b) {
|
|
106
|
+
state.true_block_end[b] = dp + (b + 1) * pdx_block_size;
|
|
107
|
+
}
|
|
108
|
+
if (tail > 0) {
|
|
109
|
+
state.block_dim[n_full_blocks] = tail;
|
|
110
|
+
state.true_block_end[n_full_blocks] = d;
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
/// Working buffers reused across pruned iterations. The trailing-block
/// buffers (Y_pdx, Y_trail) are resized only when d_trail changes.
struct IterScratch {
    /// GEMM tile of partial inner products, (bx_max, by_max).
    std::vector<float> partial_ip;
    /// Trailing centroid dims after the PDX re-layout.
    std::vector<float> Y_pdx;
    /// Row-major (k, d_trail) copy of the trailing centroid dims; the input
    /// handed to pdxify.
    std::vector<float> Y_trail;
    /// ||Y_tilde[j, 0:dp]||^2 for each centroid j.
    std::vector<float> y_norms_partial;
    /// state.assignments widened to 64 bits (size n).
    std::vector<int64_t> labels64;
    /// d_trail at the last resize of Y_pdx / Y_trail; -1 means never sized.
    int prev_d_trail{-1};
};
|
|
123
|
+
|
|
124
|
+
/// Iter 0: full GEMM via knn_L2sqr (vanilla Lloyd's). Fills
|
|
125
|
+
/// state.assignments and state.best_dists. Returns objective.
|
|
126
|
+
double run_iter0_full_gemm(int d, int k, TrainState& state) {
|
|
127
|
+
std::vector<int64_t> labels(state.n);
|
|
128
|
+
std::vector<float> distances(state.n);
|
|
129
|
+
knn_L2sqr(
|
|
130
|
+
state.X_tilde.data(),
|
|
131
|
+
state.Y_tilde.data(),
|
|
132
|
+
d,
|
|
133
|
+
state.n,
|
|
134
|
+
k,
|
|
135
|
+
/*k=*/1,
|
|
136
|
+
distances.data(),
|
|
137
|
+
labels.data(),
|
|
138
|
+
/*y_norm2=*/nullptr);
|
|
139
|
+
|
|
140
|
+
assert(!state.assignments.empty());
|
|
141
|
+
assert(!state.best_dists.empty());
|
|
142
|
+
double objective = 0.0;
|
|
143
|
+
for (int i = 0; i < state.n; ++i) {
|
|
144
|
+
state.assignments[i] = static_cast<int>(labels[i]);
|
|
145
|
+
state.best_dists[i] = distances[i];
|
|
146
|
+
objective += distances[i];
|
|
147
|
+
}
|
|
148
|
+
return objective;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
/// Iter 1+: partial GEMM over [0, d_prime) + ADSampling progressive
|
|
152
|
+
/// pruning over the PDX-laid-out trailing block. Updates
|
|
153
|
+
/// state.assignments and state.best_dists. Writes total_pairs and
|
|
154
|
+
/// pruned_at_gemm. Returns objective.
|
|
155
|
+
double run_iter_pruned(
|
|
156
|
+
int d,
|
|
157
|
+
int k,
|
|
158
|
+
const SuperKMeansParameters& cp,
|
|
159
|
+
TrainState& state,
|
|
160
|
+
IterScratch& scratch,
|
|
161
|
+
int64_t& total_pairs,
|
|
162
|
+
int64_t& pruned_at_gemm) {
|
|
163
|
+
const int dp = state.d_prime;
|
|
164
|
+
assert(dp >= 1);
|
|
165
|
+
assert(!state.ad_coeff.empty());
|
|
166
|
+
assert(!scratch.partial_ip.empty());
|
|
167
|
+
assert(!scratch.y_norms_partial.empty());
|
|
168
|
+
const int d_trail = d - dp;
|
|
169
|
+
const int n_train = state.n;
|
|
170
|
+
assert(static_cast<int>(state.best_dists.size()) >= n_train);
|
|
171
|
+
assert(static_cast<int>(state.x_norms_partial.size()) >= n_train);
|
|
172
|
+
assert(static_cast<int>(state.assignments.size()) >= n_train);
|
|
173
|
+
|
|
174
|
+
if (d_trail != scratch.prev_d_trail) {
|
|
175
|
+
scratch.Y_pdx.resize(static_cast<size_t>(k) * d_trail);
|
|
176
|
+
scratch.Y_trail.resize(static_cast<size_t>(k) * d_trail);
|
|
177
|
+
scratch.prev_d_trail = d_trail;
|
|
178
|
+
}
|
|
179
|
+
for (int j = 0; j < k; ++j) {
|
|
180
|
+
std::memcpy(
|
|
181
|
+
scratch.Y_trail.data() + static_cast<size_t>(j) * d_trail,
|
|
182
|
+
state.Y_tilde.data() + static_cast<size_t>(j) * d + dp,
|
|
183
|
+
d_trail * sizeof(float));
|
|
184
|
+
}
|
|
185
|
+
detail::pdxify(
|
|
186
|
+
scratch.Y_trail.data(),
|
|
187
|
+
k,
|
|
188
|
+
d_trail,
|
|
189
|
+
cp.pdx_block_size,
|
|
190
|
+
scratch.Y_pdx.data());
|
|
191
|
+
|
|
192
|
+
detail::compute_partial_norms(
|
|
193
|
+
state.Y_tilde.data(), k, d, dp, scratch.y_norms_partial.data());
|
|
194
|
+
|
|
195
|
+
const int n_blocks = static_cast<int>(state.block_dim.size());
|
|
196
|
+
|
|
197
|
+
for (int xi = 0; xi < n_train; xi += cp.x_batch) {
|
|
198
|
+
const int bx = std::min(cp.x_batch, n_train - xi);
|
|
199
|
+
|
|
200
|
+
// Refresh tau: recompute full-d L2 distance to the previously
|
|
201
|
+
// assigned centroid. This is intentionally over all d dims (not
|
|
202
|
+
// just d_prime) because tau must be an exact distance for the
|
|
203
|
+
// chi-squared pruning bound to be valid. Cost is O(bx * d) per
|
|
204
|
+
// x-batch, amortized across the y-batch tiles that follow.
|
|
205
|
+
#pragma omp parallel for
|
|
206
|
+
for (int i = 0; i < bx; ++i) {
|
|
207
|
+
const int j_prev = state.assignments[xi + i];
|
|
208
|
+
const float* xrow =
|
|
209
|
+
state.X_tilde.data() + static_cast<size_t>(xi + i) * d;
|
|
210
|
+
const float* yrow =
|
|
211
|
+
state.Y_tilde.data() + static_cast<size_t>(j_prev) * d;
|
|
212
|
+
float tau = 0.0f;
|
|
213
|
+
for (int m = 0; m < d; ++m) {
|
|
214
|
+
const float diff = xrow[m] - yrow[m];
|
|
215
|
+
tau += diff * diff;
|
|
216
|
+
}
|
|
217
|
+
state.best_dists[xi + i] = tau;
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
for (int yj = 0; yj < k; yj += cp.y_batch) {
|
|
221
|
+
const int by = std::min(cp.y_batch, k - yj);
|
|
222
|
+
|
|
223
|
+
// GEMM phase: column-major sgemm computes
|
|
224
|
+
// partial_ip[i*by + j] = <X[xi+i, 0:dp], Y[yj+j, 0:dp]>.
|
|
225
|
+
{
|
|
226
|
+
FINTEGER M = by;
|
|
227
|
+
FINTEGER N_ = bx;
|
|
228
|
+
FINTEGER K_ = dp;
|
|
229
|
+
float alpha = 1.0f;
|
|
230
|
+
float beta = 0.0f;
|
|
231
|
+
FINTEGER lda_y = d;
|
|
232
|
+
FINTEGER lda_x = d;
|
|
233
|
+
FINTEGER ldc = by;
|
|
234
|
+
sgemm_("Transpose",
|
|
235
|
+
"Not transpose",
|
|
236
|
+
&M,
|
|
237
|
+
&N_,
|
|
238
|
+
&K_,
|
|
239
|
+
&alpha,
|
|
240
|
+
state.Y_tilde.data() + static_cast<size_t>(yj) * d,
|
|
241
|
+
&lda_y,
|
|
242
|
+
state.X_tilde.data() + static_cast<size_t>(xi) * d,
|
|
243
|
+
&lda_x,
|
|
244
|
+
&beta,
|
|
245
|
+
scratch.partial_ip.data(),
|
|
246
|
+
&ldc);
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
// One SIMD dispatch per (xi, yj) tile — block_l2<SL> below is
|
|
250
|
+
// a direct call (no per-call switch on SIMDConfig::level).
|
|
251
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
252
|
+
[[maybe_unused]] const int omp_chunk_local = cp.omp_chunk;
|
|
253
|
+
int64_t total_pairs_local = 0;
|
|
254
|
+
int64_t pruned_at_gemm_local = 0;
|
|
255
|
+
#pragma omp parallel for schedule(dynamic, omp_chunk_local) \
|
|
256
|
+
reduction(+ : total_pairs_local) reduction(+ : pruned_at_gemm_local)
|
|
257
|
+
for (int i = 0; i < bx; ++i) {
|
|
258
|
+
// tau is the best full-d distance found so far for this
|
|
259
|
+
// point; tightened as closer centroids are found.
|
|
260
|
+
float tau = state.best_dists[xi + i];
|
|
261
|
+
int best_j = state.assignments[xi + i];
|
|
262
|
+
const float xnp_i = state.x_norms_partial[xi + i];
|
|
263
|
+
const float* xrow = state.X_tilde.data() +
|
|
264
|
+
static_cast<size_t>(xi + i) * d;
|
|
265
|
+
|
|
266
|
+
for (int j = 0; j < by; ++j) {
|
|
267
|
+
++total_pairs_local;
|
|
268
|
+
|
|
269
|
+
// L2-from-IP; clamp to handle catastrophic
|
|
270
|
+
// cancellation when the true distance is ~0.
|
|
271
|
+
float pd = xnp_i + scratch.y_norms_partial[yj + j] -
|
|
272
|
+
2.0f *
|
|
273
|
+
scratch.partial_ip
|
|
274
|
+
[static_cast<size_t>(i) * by +
|
|
275
|
+
j];
|
|
276
|
+
if (pd < 0.0f) {
|
|
277
|
+
pd = 0.0f;
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
if (pd > state.ad_coeff[dp] * tau) {
|
|
281
|
+
++pruned_at_gemm_local;
|
|
282
|
+
continue;
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
// double accumulator mitigates float drift over many
|
|
286
|
+
// block additions.
|
|
287
|
+
double dist = pd;
|
|
288
|
+
bool keep = true;
|
|
289
|
+
|
|
290
|
+
// Progressive pruning across PDX blocks. Per block:
|
|
291
|
+
// stride = k * block_dim[b] floats, column-major
|
|
292
|
+
// across centroids.
|
|
293
|
+
size_t pdx_offset = 0;
|
|
294
|
+
for (int b = 0; b < n_blocks; ++b) {
|
|
295
|
+
const int n_in_block = state.block_dim.at(b);
|
|
296
|
+
const int true_end = state.true_block_end.at(b);
|
|
297
|
+
const float* xblk = xrow + (true_end - n_in_block);
|
|
298
|
+
const float* yblk = scratch.Y_pdx.data() +
|
|
299
|
+
pdx_offset +
|
|
300
|
+
static_cast<size_t>(yj + j) * n_in_block;
|
|
301
|
+
dist += faiss::detail::block_l2<SL>(
|
|
302
|
+
xblk, yblk, n_in_block);
|
|
303
|
+
pdx_offset += static_cast<size_t>(k) * n_in_block;
|
|
304
|
+
|
|
305
|
+
if (dist >
|
|
306
|
+
static_cast<double>(state.ad_coeff[true_end]) *
|
|
307
|
+
tau) {
|
|
308
|
+
keep = false;
|
|
309
|
+
break;
|
|
310
|
+
}
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
if (keep && dist < tau) {
|
|
314
|
+
tau = static_cast<float>(dist);
|
|
315
|
+
best_j = yj + j;
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
state.best_dists[xi + i] = tau;
|
|
320
|
+
state.assignments[xi + i] = best_j;
|
|
321
|
+
}
|
|
322
|
+
total_pairs += total_pairs_local;
|
|
323
|
+
pruned_at_gemm += pruned_at_gemm_local;
|
|
324
|
+
});
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
double objective = 0.0;
|
|
329
|
+
for (int i = 0; i < n_train; ++i) {
|
|
330
|
+
objective += state.best_dists[i];
|
|
331
|
+
}
|
|
332
|
+
return objective;
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
/// Post-iteration: update centroids and split empties. Returns nsplit.
|
|
336
|
+
int update_centroids_and_split(
|
|
337
|
+
int d,
|
|
338
|
+
int k,
|
|
339
|
+
TrainState& state,
|
|
340
|
+
IterScratch& scratch,
|
|
341
|
+
std::vector<float>& hassign) {
|
|
342
|
+
std::fill(hassign.begin(), hassign.end(), 0.0f);
|
|
343
|
+
assert(!scratch.labels64.empty());
|
|
344
|
+
assert(!state.assignments.empty());
|
|
345
|
+
for (int i = 0; i < state.n; ++i) {
|
|
346
|
+
scratch.labels64[i] = static_cast<int64_t>(state.assignments[i]);
|
|
347
|
+
}
|
|
348
|
+
detail::compute_centroids(
|
|
349
|
+
d,
|
|
350
|
+
k,
|
|
351
|
+
state.n,
|
|
352
|
+
/*k_frozen=*/0,
|
|
353
|
+
reinterpret_cast<const uint8_t*>(state.X_tilde.data()),
|
|
354
|
+
/*codec=*/nullptr,
|
|
355
|
+
scratch.labels64.data(),
|
|
356
|
+
/*weights=*/nullptr,
|
|
357
|
+
hassign.data(),
|
|
358
|
+
state.Y_tilde.data());
|
|
359
|
+
if (state.n <= k) {
|
|
360
|
+
return 0;
|
|
361
|
+
}
|
|
362
|
+
return detail::split_clusters(
|
|
363
|
+
d,
|
|
364
|
+
k,
|
|
365
|
+
state.n,
|
|
366
|
+
/*k_frozen=*/0,
|
|
367
|
+
hassign.data(),
|
|
368
|
+
state.Y_tilde.data());
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
/// Stay-in-band controller: nudge state.d_prime based on observed
|
|
372
|
+
/// pruning rate. Recomputes x_norms_partial if d_prime changed. Returns
|
|
373
|
+
/// the observed pruning rate (0 when there were no pairs).
|
|
374
|
+
float adapt_d_prime(
|
|
375
|
+
int d,
|
|
376
|
+
const SuperKMeansParameters& cp,
|
|
377
|
+
TrainState& state,
|
|
378
|
+
int64_t total_pairs,
|
|
379
|
+
int64_t pruned_at_gemm) {
|
|
380
|
+
if (total_pairs == 0) {
|
|
381
|
+
return 0.0f;
|
|
382
|
+
}
|
|
383
|
+
const float pruning_rate = static_cast<float>(pruned_at_gemm) /
|
|
384
|
+
static_cast<float>(total_pairs);
|
|
385
|
+
int new_dp = state.d_prime;
|
|
386
|
+
if (pruning_rate > cp.pruning_target_high) {
|
|
387
|
+
new_dp = static_cast<int>(
|
|
388
|
+
std::lround(state.d_prime * (1.0f - cp.d_prime_adjust)));
|
|
389
|
+
} else if (pruning_rate < cp.pruning_target_low) {
|
|
390
|
+
new_dp = static_cast<int>(
|
|
391
|
+
std::lround(state.d_prime * (1.0f + cp.d_prime_adjust)));
|
|
392
|
+
}
|
|
393
|
+
new_dp = std::max(cp.d_prime_min, new_dp);
|
|
394
|
+
new_dp = std::min(d / 2, new_dp);
|
|
395
|
+
if (new_dp != state.d_prime) {
|
|
396
|
+
state.d_prime = new_dp;
|
|
397
|
+
detail::compute_partial_norms(
|
|
398
|
+
state.X_tilde.data(),
|
|
399
|
+
state.n,
|
|
400
|
+
d,
|
|
401
|
+
state.d_prime,
|
|
402
|
+
state.x_norms_partial.data());
|
|
403
|
+
rebuild_pdx_block_layout(d, cp.pdx_block_size, state);
|
|
404
|
+
}
|
|
405
|
+
return pruning_rate;
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
/// Pre-loop setup: subsample, rotate, Forgy init, build ADSampling table,
|
|
409
|
+
/// allocate scratch. Returned `sampled_x_owner` keeps the subsampled buffer
|
|
410
|
+
/// alive when subsampling occurred (otherwise empty).
|
|
411
|
+
std::unique_ptr<uint8_t[]> setup_train_state(
|
|
412
|
+
TrainState& state,
|
|
413
|
+
IterScratch& scratch,
|
|
414
|
+
std::vector<float>& hassign,
|
|
415
|
+
const SuperKMeansParameters& cp,
|
|
416
|
+
int d,
|
|
417
|
+
int k,
|
|
418
|
+
idx_t n,
|
|
419
|
+
const float* x) {
|
|
420
|
+
const size_t line_size = sizeof(float) * static_cast<size_t>(d);
|
|
421
|
+
idx_t nx = n;
|
|
422
|
+
const uint8_t* x_bytes = reinterpret_cast<const uint8_t*>(x);
|
|
423
|
+
std::unique_ptr<uint8_t[]> sampled_x_owner;
|
|
424
|
+
if (static_cast<size_t>(nx) >
|
|
425
|
+
static_cast<size_t>(k) * cp.max_points_per_centroid) {
|
|
426
|
+
Clustering tmp_clus(d, k, cp);
|
|
427
|
+
uint8_t* x_new = nullptr;
|
|
428
|
+
float* w_unused = nullptr;
|
|
429
|
+
nx = detail::subsample_training_set(
|
|
430
|
+
tmp_clus,
|
|
431
|
+
nx,
|
|
432
|
+
x_bytes,
|
|
433
|
+
line_size,
|
|
434
|
+
/*weights=*/nullptr,
|
|
435
|
+
&x_new,
|
|
436
|
+
&w_unused);
|
|
437
|
+
FAISS_ASSERT(x_new != nullptr);
|
|
438
|
+
sampled_x_owner.reset(x_new);
|
|
439
|
+
x_bytes = x_new;
|
|
440
|
+
}
|
|
441
|
+
const float* x_sampled = reinterpret_cast<const float*>(x_bytes);
|
|
442
|
+
|
|
443
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
444
|
+
nx <= static_cast<idx_t>(std::numeric_limits<int>::max()),
|
|
445
|
+
"SuperKMeans: training set size exceeds INT_MAX after sampling");
|
|
446
|
+
state.n = static_cast<int>(nx);
|
|
447
|
+
|
|
448
|
+
state.R.init(cp.seed);
|
|
449
|
+
|
|
450
|
+
state.X_tilde.resize(static_cast<size_t>(state.n) * d);
|
|
451
|
+
state.R.apply_noalloc(state.n, x_sampled, state.X_tilde.data());
|
|
452
|
+
|
|
453
|
+
// Forgy init: pick k random rows from the rotated pool as initial
|
|
454
|
+
// centroids. These remain in rotated space; un-rotation happens
|
|
455
|
+
// after the iteration loop.
|
|
456
|
+
state.Y_tilde.resize(static_cast<size_t>(k) * d);
|
|
457
|
+
{
|
|
458
|
+
std::vector<int> perm(state.n);
|
|
459
|
+
rand_perm(perm.data(), state.n, static_cast<int64_t>(cp.seed) + 1);
|
|
460
|
+
for (int j = 0; j < k; ++j) {
|
|
461
|
+
std::memcpy(
|
|
462
|
+
state.Y_tilde.data() + static_cast<size_t>(j) * d,
|
|
463
|
+
state.X_tilde.data() + static_cast<size_t>(perm[j]) * d,
|
|
464
|
+
sizeof(float) * d);
|
|
465
|
+
}
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
state.d_prime =
|
|
469
|
+
std::max(cp.d_prime_min, static_cast<int>(d * cp.d_prime_fraction));
|
|
470
|
+
state.d_prime = std::min(state.d_prime, d / 2);
|
|
471
|
+
rebuild_pdx_block_layout(d, cp.pdx_block_size, state);
|
|
472
|
+
|
|
473
|
+
// Iter 1+ uses L2-from-IP only over [0, d_prime), so full ||X[i]||^2 is
|
|
474
|
+
// never read; iter 0 routes through knn_L2sqr which carries its own.
|
|
475
|
+
state.x_norms_partial.resize(state.n);
|
|
476
|
+
detail::compute_partial_norms(
|
|
477
|
+
state.X_tilde.data(),
|
|
478
|
+
state.n,
|
|
479
|
+
d,
|
|
480
|
+
state.d_prime,
|
|
481
|
+
state.x_norms_partial.data());
|
|
482
|
+
|
|
483
|
+
const double epsilon = static_cast<double>(cp.ad_epsilon_factor) / d;
|
|
484
|
+
state.ad_coeff = detail::precompute_ad_thresholds(d, epsilon);
|
|
485
|
+
FAISS_ASSERT_MSG(
|
|
486
|
+
state.ad_coeff.size() == static_cast<size_t>(d + 1),
|
|
487
|
+
"ad_coeff size mismatch");
|
|
488
|
+
|
|
489
|
+
state.assignments.assign(state.n, 0);
|
|
490
|
+
state.best_dists.assign(state.n, std::numeric_limits<float>::max());
|
|
491
|
+
|
|
492
|
+
hassign.assign(k, 0.0f);
|
|
493
|
+
|
|
494
|
+
const int by_max = std::min(cp.y_batch, k);
|
|
495
|
+
const int bx_max = std::min(cp.x_batch, state.n);
|
|
496
|
+
scratch.partial_ip.resize(static_cast<size_t>(bx_max) * by_max);
|
|
497
|
+
scratch.y_norms_partial.resize(k);
|
|
498
|
+
scratch.labels64.resize(state.n);
|
|
499
|
+
|
|
500
|
+
return sampled_x_owner;
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
/// Un-rotate centroids into output buffer. R orthogonal, so
|
|
504
|
+
/// reverse_transform applies R^T = R^-1.
|
|
505
|
+
void untransform_centroids(
|
|
506
|
+
std::vector<float>& centroids,
|
|
507
|
+
const RandomRotationMatrix& R,
|
|
508
|
+
int d,
|
|
509
|
+
int k,
|
|
510
|
+
const float* Y_tilde) {
|
|
511
|
+
centroids.resize(static_cast<size_t>(k) * d);
|
|
512
|
+
R.reverse_transform(k, Y_tilde, centroids.data());
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
} // namespace
|
|
516
|
+
|
|
517
|
+
/// Construct a SuperKMeans trainer for k centroids in dimension d.
///
/// All parameters are validated here via FAISS_THROW_IF_NOT_* so that
/// train() can rely on a consistent configuration. The first invalid
/// parameter triggers a throw.
SuperKMeans::SuperKMeans(int d, int k, const SuperKMeansParameters& cp_in)
        : cp(cp_in), d(d), k(k) {
    FAISS_THROW_IF_NOT_MSG(d > 0, "SuperKMeans: d must be positive");
    FAISS_THROW_IF_NOT_MSG(k > 0, "SuperKMeans: k must be positive");
    FAISS_THROW_IF_NOT_MSG(
            cp.d_prime_fraction > 0.0f && cp.d_prime_fraction <= 1.0f,
            "SuperKMeans: d_prime_fraction must be in (0, 1]");
    FAISS_THROW_IF_NOT_MSG(
            cp.d_prime_adjust >= 0.0f && cp.d_prime_adjust < 1.0f,
            "SuperKMeans: d_prime_adjust must be in [0, 1)");
    // d >= 2 * d_prime_min keeps d_prime in the chi-squared validity
    // range after both clamping steps (floor at d_prime_min, ceiling
    // at d/2). See AdSampling.h on the p >= 16 contract.
    // Expressed as d_prime_min <= d / 2: for d > 0 this is equivalent to
    // d >= 2 * d_prime_min, but it cannot overflow int when d_prime_min is
    // huge (or very negative).
    FAISS_THROW_IF_NOT_FMT(
            cp.d_prime_min <= d / 2,
            "SuperKMeans: d (%d) must be >= 2 * d_prime_min (%d)",
            d,
            cp.d_prime_min);
    FAISS_THROW_IF_NOT_MSG(
            cp.d_prime_min >= 16,
            "SuperKMeans: d_prime_min must be >= 16 (chi-squared validity floor)");
    FAISS_THROW_IF_NOT_MSG(
            cp.pdx_block_size > 0, "SuperKMeans: pdx_block_size must be > 0");
    FAISS_THROW_IF_NOT_MSG(cp.x_batch > 0, "SuperKMeans: x_batch must be > 0");
    FAISS_THROW_IF_NOT_MSG(cp.y_batch > 0, "SuperKMeans: y_batch must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            cp.pruning_target_low > 0.0f &&
                    cp.pruning_target_low <= cp.pruning_target_high &&
                    cp.pruning_target_high < 1.0f,
            "SuperKMeans: require 0 < pruning_target_low <= pruning_target_high < 1");
    // epsilon = ad_epsilon_factor / d is the chi-squared significance
    // level. precompute_ad_thresholds requires epsilon in (0, 1).
    FAISS_THROW_IF_NOT_MSG(
            cp.ad_epsilon_factor > 0.0f &&
                    cp.ad_epsilon_factor < static_cast<float>(d),
            "SuperKMeans: ad_epsilon_factor must be in (0, d) "
            "so epsilon = factor/d is in (0,1)");
    FAISS_THROW_IF_NOT_MSG(
            cp.omp_chunk > 0, "SuperKMeans: omp_chunk must be > 0");
    // With 0, train() always subsamples (n > k * 0) down to k * 0 = 0 rows,
    // leaving no points for the Forgy initialization to draw from.
    FAISS_THROW_IF_NOT_MSG(
            cp.max_points_per_centroid != 0,
            "SuperKMeans: max_points_per_centroid must be nonzero");
}
|
|
557
|
+
|
|
558
|
+
/// Train k centroids on n vectors of dimension d (row-major in x).
///
/// Iteration 0 runs a full-GEMM assignment. Later iterations run the pruned
/// (ADSampling / PDX) assignment, and the controller adapts d_prime to the
/// observed GEMM pruning rate. Per-iteration statistics go to
/// iteration_stats and gemm_pruning_rates. The final centroids are
/// un-rotated into `centroids`.
///
/// A non-positive cp.niter runs no iterations. The centroids are then the
/// un-rotated Forgy initialization.
///
/// @throws via FAISS_THROW_IF_NOT_MSG if n <= 0, x is null, n < k, or
///         (when cp.check_input_data_for_NaNs) x contains NaN/Inf.
void SuperKMeans::train(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(n > 0, "SuperKMeans: n must be positive");
    FAISS_THROW_IF_NOT_MSG(x != nullptr, "SuperKMeans: x must not be null");
    FAISS_THROW_IF_NOT_MSG(
            n >= static_cast<idx_t>(k), "SuperKMeans: n must be >= k");
    if (cp.check_input_data_for_NaNs) {
        const size_t n_values = static_cast<size_t>(n) * d;
        for (size_t i = 0; i < n_values; i++) {
            FAISS_THROW_IF_NOT_MSG(
                    std::isfinite(x[i]),
                    "SuperKMeans: input contains NaN's or Inf's");
        }
    }
    if (cp.verbose && n < static_cast<idx_t>(k) * cp.min_points_per_centroid) {
        printf("WARNING: clustering %" PRId64
               " points to %d centroids: please provide at least "
               "%" PRId64 " training points\n",
               n,
               k,
               static_cast<idx_t>(k) * cp.min_points_per_centroid);
    }

    TrainState state(d);
    IterScratch scratch;
    std::vector<float> hassign;
    // Owns the subsampled copy of x when subsampling happened.
    [[maybe_unused]] auto sampled_x_owner =
            setup_train_state(state, scratch, hassign, cp, d, k, n, x);

    // Clamp before reserve(): a negative niter converted to size_t would be
    // enormous and make reserve() throw std::length_error.
    const size_t niter_reserve = static_cast<size_t>(std::max(cp.niter, 0));
    iteration_stats.clear();
    iteration_stats.reserve(niter_reserve);
    gemm_pruning_rates.clear();
    gemm_pruning_rates.reserve(niter_reserve);

    // If steady-state GEMM pruning stays below kLowPruningRate for
    // kLowPruningStreak consecutive iterations, a verbose warning is printed
    // once.
    constexpr float kLowPruningRate = 0.85f;
    constexpr int kLowPruningStreak = 3;

    const double t_train_start = getmillisecs();

    for (int iter = 0; iter < cp.niter; ++iter) {
        const double t_iter_start = getmillisecs();
        double objective = 0.0;
        int64_t total_pairs = 0;
        int64_t pruned_at_gemm = 0;

        if (iter == 0) {
            objective = run_iter0_full_gemm(d, k, state);
        } else {
            objective = run_iter_pruned(
                    d, k, cp, state, scratch, total_pairs, pruned_at_gemm);
        }

        const int nsplit =
                update_centroids_and_split(d, k, state, scratch, hassign);
        // Iteration 0 does no pruning, so there is no rate to adapt to.
        const float pruning_rate = (iter == 0)
                ? 0.0f
                : adapt_d_prime(d, cp, state, total_pairs, pruned_at_gemm);

        ClusteringIterationStats stat{};
        stat.obj = static_cast<float>(objective);
        stat.time = (getmillisecs() - t_iter_start) / 1000.0;
        stat.time_search = stat.time;
        stat.imbalance_factor = std::numeric_limits<double>::quiet_NaN();
        stat.nsplit = nsplit;
        iteration_stats.push_back(stat);
        gemm_pruning_rates.push_back(pruning_rate);

        if (iter > 0) {
            if (pruning_rate < kLowPruningRate) {
                state.low_pruning_streak++;
            } else {
                state.low_pruning_streak = 0;
            }
            if (cp.verbose && state.low_pruning_streak >= kLowPruningStreak &&
                !state.low_pruning_warning_printed) {
                fprintf(stderr,
                        "WARNING: SuperKMeans steady-state pruning < %.2f for %d+ iters "
                        "(current=%.2f). Data may not be a good fit for ADSampling; "
                        "consider falling back to faiss::Clustering.\n",
                        static_cast<double>(kLowPruningRate),
                        kLowPruningStreak,
                        pruning_rate);
                state.low_pruning_warning_printed = true;
            }
        }

        if (cp.verbose) {
            printf(" Iter %d: obj=%g time=%.3fs prune=%.4f dp=%d nsplit=%d\n",
                   iter,
                   stat.obj,
                   stat.time,
                   pruning_rate,
                   state.d_prime,
                   nsplit);
        }
    }

    if (cp.verbose) {
        printf("Total training time: %.3fs\n",
               (getmillisecs() - t_train_start) / 1000.0);
    }

    untransform_centroids(centroids, state.R, d, k, state.Y_tilde.data());
}
|
|
655
|
+
|
|
656
|
+
} // namespace faiss
|