faiss 0.6.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +2 -1
- data/ext/faiss/{index_rb.cpp → index.cpp} +1 -1
- data/ext/faiss/index_binary.cpp +1 -1
- data/ext/faiss/kmeans.cpp +1 -1
- data/ext/faiss/pca_matrix.cpp +1 -1
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/ext/faiss/{utils_rb.cpp → utils.cpp} +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +93 -80
- data/vendor/faiss/faiss/Clustering.cpp +39 -240
- data/vendor/faiss/faiss/Clustering.h +6 -0
- data/vendor/faiss/faiss/IVFlib.cpp +41 -21
- data/vendor/faiss/faiss/Index.cpp +6 -5
- data/vendor/faiss/faiss/Index.h +5 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +49 -37
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexBinary.cpp +5 -3
- data/vendor/faiss/faiss/IndexBinary.h +4 -4
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +84 -92
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
- data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +87 -415
- data/vendor/faiss/faiss/IndexFastScan.cpp +72 -109
- data/vendor/faiss/faiss/IndexFastScan.h +25 -23
- data/vendor/faiss/faiss/IndexFlat.cpp +27 -20
- data/vendor/faiss/faiss/IndexFlat.h +21 -18
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +42 -19
- data/vendor/faiss/faiss/IndexHNSW.cpp +283 -145
- data/vendor/faiss/faiss/IndexHNSW.h +16 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +25 -21
- data/vendor/faiss/faiss/IndexIDMap.h +9 -7
- data/vendor/faiss/faiss/IndexIVF.cpp +465 -362
- data/vendor/faiss/faiss/IndexIVF.h +33 -12
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +96 -93
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +357 -238
- data/vendor/faiss/faiss/IndexIVFFastScan.h +42 -41
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +36 -68
- data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +53 -30
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +71 -843
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +151 -121
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +21 -17
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +26 -39
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +475 -476
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +248 -93
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +36 -19
- data/vendor/faiss/faiss/IndexLattice.cpp +13 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +36 -21
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +39 -23
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +31 -11
- data/vendor/faiss/faiss/IndexPQ.cpp +128 -221
- data/vendor/faiss/faiss/IndexPQ.h +3 -2
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
- data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +11 -36
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +41 -277
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +183 -27
- data/vendor/faiss/faiss/IndexRefine.cpp +30 -25
- data/vendor/faiss/faiss/IndexRefine.h +4 -4
- data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
- data/vendor/faiss/faiss/IndexShards.cpp +10 -9
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
- data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
- data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
- data/vendor/faiss/faiss/MetaIndexes.h +1 -1
- data/vendor/faiss/faiss/MetricType.h +14 -7
- data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
- data/vendor/faiss/faiss/SuperKMeans.h +97 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +237 -149
- data/vendor/faiss/faiss/VectorTransform.h +16 -16
- data/vendor/faiss/faiss/build.cpp +23 -0
- data/vendor/faiss/faiss/build.h +15 -0
- data/vendor/faiss/faiss/clone_index.cpp +48 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
- data/vendor/faiss/faiss/factory_tools.cpp +5 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
- data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
- data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
- data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
- data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
- data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
- data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
- data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
- data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
- data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +29 -25
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
- data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +16 -16
- data/vendor/faiss/faiss/impl/CodePacker.cpp +3 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
- data/vendor/faiss/faiss/impl/FaissAssert.h +6 -3
- data/vendor/faiss/faiss/impl/FaissException.h +50 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +92 -317
- data/vendor/faiss/faiss/impl/HNSW.h +13 -34
- data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
- data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
- data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +82 -77
- data/vendor/faiss/faiss/impl/NNDescent.cpp +62 -25
- data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +38 -21
- data/vendor/faiss/faiss/impl/NSG.h +4 -4
- data/vendor/faiss/faiss/impl/Panorama.cpp +23 -6
- data/vendor/faiss/faiss/impl/Panorama.h +258 -87
- data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
- data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +46 -32
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +30 -23
- data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +55 -49
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +65 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +296 -283
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +26 -23
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +99 -75
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +52 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
- data/vendor/faiss/faiss/impl/VisitedTable.h +7 -0
- data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
- data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
- data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
- data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
- data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
- data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
- data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
- data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
- data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +8 -3
- data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
- data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
- data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
- data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +169 -2
- data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
- data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
- data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
- data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -356
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
- data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
- data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +282 -134
- data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
- data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +1132 -45
- data/vendor/faiss/faiss/impl/index_read_utils.h +1 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +95 -13
- data/vendor/faiss/faiss/impl/io.cpp +6 -6
- data/vendor/faiss/faiss/impl/io_macros.h +33 -16
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +37 -23
- data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
- data/vendor/faiss/faiss/impl/mapped_io.cpp +6 -6
- data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx2.cpp → pq_code_distance-avx2.h} +9 -13
- data/vendor/faiss/faiss/impl/pq_code_distance/{pq_code_distance-avx512.cpp → pq_code_distance-avx512.h} +9 -57
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +29 -111
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +238 -5
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -7
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +311 -477
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +1 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +3 -2
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +102 -11
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +27 -1
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +3 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +148 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +167 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +59 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +163 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +192 -8
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +12 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +100 -66
- data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +264 -172
- data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
- data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +270 -218
- data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
- data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
- data/vendor/faiss/faiss/impl/svs_io.h +8 -2
- data/vendor/faiss/faiss/index_factory.cpp +86 -18
- data/vendor/faiss/faiss/index_io.h +24 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +66 -16
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
- data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
- data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
- data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +18 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +12 -3
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +7 -2
- data/vendor/faiss/faiss/utils/Heap.cpp +10 -10
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +47 -36
- data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
- data/vendor/faiss/faiss/utils/distances.cpp +390 -560
- data/vendor/faiss/faiss/utils/distances.h +20 -1
- data/vendor/faiss/faiss/utils/distances_dispatch.h +117 -37
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -177
- data/vendor/faiss/faiss/utils/extra_distances.cpp +9 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +32 -6
- data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
- data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
- data/vendor/faiss/faiss/utils/hamming.h +92 -2
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
- data/vendor/faiss/faiss/utils/partitioning.h +31 -0
- data/vendor/faiss/faiss/utils/popcount.h +29 -0
- data/vendor/faiss/faiss/utils/pq_code_distance.h +2 -2
- data/vendor/faiss/faiss/utils/prefetch.h +2 -2
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
- data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
- data/vendor/faiss/faiss/utils/rabitq_simd.h +57 -536
- data/vendor/faiss/faiss/utils/random.cpp +6 -6
- data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +5 -1
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +213 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +163 -10
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +250 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +7 -4
- data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +2 -1
- data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
- data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
- data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +17 -5
- data/vendor/faiss/faiss/utils/simd_levels.h +93 -1
- data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
- data/vendor/faiss/faiss/utils/utils.cpp +5 -5
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +119 -34
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
- data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -224
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -230
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -235
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -449
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
- data/vendor/faiss/faiss/utils/simdlib.h +0 -42
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -365
- /data/ext/faiss/{utils_rb.h → utils.h} +0 -0
|
@@ -150,7 +150,9 @@ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
|
|
|
150
150
|
}
|
|
151
151
|
|
|
152
152
|
void VectorTransform::check_identical(const VectorTransform& other) const {
|
|
153
|
-
|
|
153
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
154
|
+
other.d_in == d_in && other.d_out == d_out,
|
|
155
|
+
"transforms must have matching d_in and d_out");
|
|
154
156
|
}
|
|
155
157
|
|
|
156
158
|
/*********************************************
|
|
@@ -158,9 +160,9 @@ void VectorTransform::check_identical(const VectorTransform& other) const {
|
|
|
158
160
|
*********************************************/
|
|
159
161
|
|
|
160
162
|
/// both d_in > d_out and d_in < d_out are supported
|
|
161
|
-
LinearTransform::LinearTransform(int
|
|
162
|
-
: VectorTransform(
|
|
163
|
-
have_bias(
|
|
163
|
+
LinearTransform::LinearTransform(int din, int dout, bool have_bias_in)
|
|
164
|
+
: VectorTransform(din, dout),
|
|
165
|
+
have_bias(have_bias_in),
|
|
164
166
|
is_orthonormal(false),
|
|
165
167
|
verbose(false) {
|
|
166
168
|
is_trained = false; // will be trained when A and b are initialized
|
|
@@ -171,21 +173,25 @@ void LinearTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
171
173
|
|
|
172
174
|
float c_factor;
|
|
173
175
|
if (have_bias) {
|
|
174
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
176
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
177
|
+
b.size() == static_cast<size_t>(d_out), "Bias not initialized");
|
|
175
178
|
float* xi = xt;
|
|
176
|
-
for (
|
|
177
|
-
for (int j = 0; j < d_out; j++)
|
|
179
|
+
for (idx_t i = 0; i < n; i++) {
|
|
180
|
+
for (int j = 0; j < d_out; j++) {
|
|
178
181
|
*xi++ = b[j];
|
|
182
|
+
}
|
|
183
|
+
}
|
|
179
184
|
c_factor = 1.0;
|
|
180
185
|
} else {
|
|
181
186
|
c_factor = 0.0;
|
|
182
187
|
}
|
|
183
188
|
|
|
184
189
|
FAISS_THROW_IF_NOT_MSG(
|
|
185
|
-
A.size() == d_out * d_in,
|
|
190
|
+
A.size() == static_cast<size_t>(d_out) * d_in,
|
|
191
|
+
"Transformation matrix not initialized");
|
|
186
192
|
|
|
187
193
|
float one = 1;
|
|
188
|
-
FINTEGER nbiti = d_out, ni = n, di = d_in;
|
|
194
|
+
FINTEGER nbiti = d_out, ni = static_cast<FINTEGER>(n), di = d_in;
|
|
189
195
|
sgemm_("Transposed",
|
|
190
196
|
"Not transposed",
|
|
191
197
|
&nbiti,
|
|
@@ -203,20 +209,21 @@ void LinearTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
203
209
|
|
|
204
210
|
void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
|
|
205
211
|
const {
|
|
212
|
+
std::vector<float> y_bias_corrected;
|
|
206
213
|
if (have_bias) { // allocate buffer to store bias-corrected data
|
|
207
|
-
|
|
214
|
+
y_bias_corrected.resize(n * d_out);
|
|
208
215
|
const float* yr = y;
|
|
209
|
-
float* yw =
|
|
216
|
+
float* yw = y_bias_corrected.data();
|
|
210
217
|
for (idx_t i = 0; i < n; i++) {
|
|
211
218
|
for (int j = 0; j < d_out; j++) {
|
|
212
219
|
*yw++ = *yr++ - b[j];
|
|
213
220
|
}
|
|
214
221
|
}
|
|
215
|
-
y =
|
|
222
|
+
y = y_bias_corrected.data();
|
|
216
223
|
}
|
|
217
224
|
|
|
218
225
|
{
|
|
219
|
-
FINTEGER dii = d_in, doi = d_out, ni = n;
|
|
226
|
+
FINTEGER dii = d_in, doi = d_out, ni = static_cast<FINTEGER>(n);
|
|
220
227
|
float one = 1.0, zero = 0.0;
|
|
221
228
|
sgemm_("Not",
|
|
222
229
|
"Not",
|
|
@@ -232,9 +239,6 @@ void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
|
|
|
232
239
|
x,
|
|
233
240
|
&dii);
|
|
234
241
|
}
|
|
235
|
-
|
|
236
|
-
if (have_bias)
|
|
237
|
-
delete[] y;
|
|
238
242
|
}
|
|
239
243
|
|
|
240
244
|
void LinearTransform::set_is_orthonormal() {
|
|
@@ -249,7 +253,7 @@ void LinearTransform::set_is_orthonormal() {
|
|
|
249
253
|
}
|
|
250
254
|
|
|
251
255
|
double eps = 4e-5;
|
|
252
|
-
FAISS_ASSERT(A.size() >= d_out * d_in);
|
|
256
|
+
FAISS_ASSERT(A.size() >= static_cast<size_t>(d_out) * d_in);
|
|
253
257
|
{
|
|
254
258
|
std::vector<float> ATA(d_out * d_out);
|
|
255
259
|
FINTEGER dii = d_in, doi = d_out;
|
|
@@ -273,9 +277,10 @@ void LinearTransform::set_is_orthonormal() {
|
|
|
273
277
|
for (long i = 0; i < d_out; i++) {
|
|
274
278
|
for (long j = 0; j < d_out; j++) {
|
|
275
279
|
float v = ATA[i + j * d_out];
|
|
276
|
-
if (i == j)
|
|
280
|
+
if (i == j) {
|
|
277
281
|
v -= 1;
|
|
278
|
-
|
|
282
|
+
}
|
|
283
|
+
if (std::fabs(v) > eps) {
|
|
279
284
|
is_orthonormal = false;
|
|
280
285
|
}
|
|
281
286
|
}
|
|
@@ -298,10 +303,13 @@ void LinearTransform::print_if_verbose(
|
|
|
298
303
|
const std::vector<double>& mat,
|
|
299
304
|
int n,
|
|
300
305
|
int d) const {
|
|
301
|
-
if (!verbose)
|
|
306
|
+
if (!verbose) {
|
|
302
307
|
return;
|
|
308
|
+
}
|
|
303
309
|
printf("matrix %s: %d*%d [\n", name, n, d);
|
|
304
|
-
|
|
310
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
311
|
+
mat.size() >= static_cast<size_t>(n) * d,
|
|
312
|
+
"matrix size is too small for the given dimensions");
|
|
305
313
|
for (int i = 0; i < n; i++) {
|
|
306
314
|
for (int j = 0; j < d; j++) {
|
|
307
315
|
printf("%10.5g ", mat[i * d + j]);
|
|
@@ -314,8 +322,10 @@ void LinearTransform::print_if_verbose(
|
|
|
314
322
|
void LinearTransform::check_identical(const VectorTransform& other_in) const {
|
|
315
323
|
VectorTransform::check_identical(other_in);
|
|
316
324
|
auto other = dynamic_cast<const LinearTransform*>(&other_in);
|
|
317
|
-
|
|
318
|
-
|
|
325
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to LinearTransform");
|
|
326
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
327
|
+
other->A == A && other->b == b,
|
|
328
|
+
"LinearTransform matrix A and bias vector b must match");
|
|
319
329
|
}
|
|
320
330
|
|
|
321
331
|
/*********************************************
|
|
@@ -388,7 +398,8 @@ static void generate_signs(
|
|
|
388
398
|
std::vector<float>& s1,
|
|
389
399
|
std::vector<float>& s2,
|
|
390
400
|
std::vector<float>& s3) {
|
|
391
|
-
|
|
401
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
402
|
+
p > 0, "number of Hadamard factors p must be positive");
|
|
392
403
|
SplitMix64RandomGenerator rng(seed);
|
|
393
404
|
s1.resize(p);
|
|
394
405
|
s2.resize(p);
|
|
@@ -424,9 +435,15 @@ void HadamardRotation::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
424
435
|
|
|
425
436
|
size_t d = d_in;
|
|
426
437
|
size_t p = d_out;
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
438
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
439
|
+
signs1.size() == p,
|
|
440
|
+
"sign-flip vector 1 size must match output dimension");
|
|
441
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
442
|
+
signs2.size() == p,
|
|
443
|
+
"sign-flip vector 2 size must match output dimension");
|
|
444
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
445
|
+
signs3.size() == p,
|
|
446
|
+
"sign-flip vector 3 size must match output dimension");
|
|
430
447
|
|
|
431
448
|
// Each unnormalized FWHT scales norms by sqrt(p).
|
|
432
449
|
// Three rounds scale by p^(3/2). Normalize once at the end.
|
|
@@ -466,10 +483,14 @@ void HadamardRotation::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
466
483
|
|
|
467
484
|
void HadamardRotation::check_identical(const VectorTransform& other) const {
|
|
468
485
|
auto* hr = dynamic_cast<const HadamardRotation*>(&other);
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
486
|
+
FAISS_THROW_IF_NOT_MSG(hr, "failed to cast to HadamardRotation");
|
|
487
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
488
|
+
d_in == hr->d_in, "HadamardRotation input dimensions must match");
|
|
489
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
490
|
+
d_out == hr->d_out,
|
|
491
|
+
"HadamardRotation output dimensions must match");
|
|
492
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
493
|
+
seed == hr->seed, "HadamardRotation seeds must match");
|
|
473
494
|
}
|
|
474
495
|
|
|
475
496
|
/*********************************************
|
|
@@ -477,13 +498,13 @@ void HadamardRotation::check_identical(const VectorTransform& other) const {
|
|
|
477
498
|
*********************************************/
|
|
478
499
|
|
|
479
500
|
PCAMatrix::PCAMatrix(
|
|
480
|
-
int
|
|
481
|
-
int
|
|
482
|
-
float
|
|
483
|
-
bool
|
|
484
|
-
: LinearTransform(
|
|
485
|
-
eigen_power(
|
|
486
|
-
random_rotation(
|
|
501
|
+
int din,
|
|
502
|
+
int dout,
|
|
503
|
+
float eigen_power_in,
|
|
504
|
+
bool random_rotation_in)
|
|
505
|
+
: LinearTransform(din, dout, true),
|
|
506
|
+
eigen_power(eigen_power_in),
|
|
507
|
+
random_rotation(random_rotation_in) {
|
|
487
508
|
is_trained = false;
|
|
488
509
|
max_points_per_d = 1000;
|
|
489
510
|
balanced_bins = 0;
|
|
@@ -497,7 +518,7 @@ namespace {
|
|
|
497
518
|
|
|
498
519
|
void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
499
520
|
{ // compute eigenvalues and vectors
|
|
500
|
-
FINTEGER info = 0, lwork = -1, di = d_in;
|
|
521
|
+
FINTEGER info = 0, lwork = -1, di = static_cast<FINTEGER>(d_in);
|
|
501
522
|
double workq;
|
|
502
523
|
|
|
503
524
|
dsyev_("Vectors as well",
|
|
@@ -509,8 +530,8 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
509
530
|
&workq,
|
|
510
531
|
&lwork,
|
|
511
532
|
&info);
|
|
512
|
-
lwork = FINTEGER(workq);
|
|
513
|
-
double
|
|
533
|
+
lwork = static_cast<FINTEGER>(workq);
|
|
534
|
+
std::vector<double> work(lwork);
|
|
514
535
|
|
|
515
536
|
dsyev_("Vectors as well",
|
|
516
537
|
"Upper",
|
|
@@ -518,12 +539,10 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
518
539
|
cov,
|
|
519
540
|
&di,
|
|
520
541
|
eigenvalues,
|
|
521
|
-
work,
|
|
542
|
+
work.data(),
|
|
522
543
|
&lwork,
|
|
523
544
|
&info);
|
|
524
545
|
|
|
525
|
-
delete[] work;
|
|
526
|
-
|
|
527
546
|
if (info != 0) {
|
|
528
547
|
fprintf(stderr,
|
|
529
548
|
"WARN ssyev info returns %d, "
|
|
@@ -534,15 +553,17 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
534
553
|
|
|
535
554
|
if (verbose && d_in <= 10) {
|
|
536
555
|
printf("info=%ld new eigvals=[", long(info));
|
|
537
|
-
for (
|
|
556
|
+
for (size_t j = 0; j < d_in; j++) {
|
|
538
557
|
printf("%g ", eigenvalues[j]);
|
|
558
|
+
}
|
|
539
559
|
printf("]\n");
|
|
540
560
|
|
|
541
561
|
double* ci = cov;
|
|
542
562
|
printf("eigenvecs=\n");
|
|
543
|
-
for (
|
|
544
|
-
for (
|
|
563
|
+
for (size_t i = 0; i < d_in; i++) {
|
|
564
|
+
for (size_t j = 0; j < d_in; j++) {
|
|
545
565
|
printf("%10.4g ", *ci++);
|
|
566
|
+
}
|
|
546
567
|
printf("\n");
|
|
547
568
|
}
|
|
548
569
|
}
|
|
@@ -550,12 +571,13 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
|
550
571
|
|
|
551
572
|
// revert order of eigenvectors & values
|
|
552
573
|
|
|
553
|
-
for (
|
|
574
|
+
for (size_t i = 0; i < d_in / 2; i++) {
|
|
554
575
|
std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
|
|
555
576
|
double* v1 = cov + i * d_in;
|
|
556
577
|
double* v2 = cov + (d_in - 1 - i) * d_in;
|
|
557
|
-
for (
|
|
578
|
+
for (size_t j = 0; j < d_in; j++) {
|
|
558
579
|
std::swap(v1[j], v2[j]);
|
|
580
|
+
}
|
|
559
581
|
}
|
|
560
582
|
}
|
|
561
583
|
|
|
@@ -571,17 +593,20 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
571
593
|
mean.resize(d_in, 0.0);
|
|
572
594
|
if (have_bias) { // we may want to skip the bias
|
|
573
595
|
const float* xi = x;
|
|
574
|
-
for (
|
|
575
|
-
for (int j = 0; j < d_in; j++)
|
|
596
|
+
for (idx_t i = 0; i < n; i++) {
|
|
597
|
+
for (int j = 0; j < d_in; j++) {
|
|
576
598
|
mean[j] += *xi++;
|
|
599
|
+
}
|
|
577
600
|
}
|
|
578
|
-
for (int j = 0; j < d_in; j++)
|
|
601
|
+
for (int j = 0; j < d_in; j++) {
|
|
579
602
|
mean[j] /= n;
|
|
603
|
+
}
|
|
580
604
|
}
|
|
581
605
|
if (verbose) {
|
|
582
606
|
printf("mean=[");
|
|
583
|
-
for (int j = 0; j < d_in; j++)
|
|
607
|
+
for (int j = 0; j < d_in; j++) {
|
|
584
608
|
printf("%g ", mean[j]);
|
|
609
|
+
}
|
|
585
610
|
printf("]\n");
|
|
586
611
|
}
|
|
587
612
|
|
|
@@ -592,12 +617,13 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
592
617
|
{ // initialize with mean * mean^T term
|
|
593
618
|
float* ci = cov;
|
|
594
619
|
for (int i = 0; i < d_in; i++) {
|
|
595
|
-
for (int j = 0; j < d_in; j++)
|
|
620
|
+
for (int j = 0; j < d_in; j++) {
|
|
596
621
|
*ci++ = -n * mean[i] * mean[j];
|
|
622
|
+
}
|
|
597
623
|
}
|
|
598
624
|
}
|
|
599
625
|
{
|
|
600
|
-
FINTEGER di = d_in, ni = n;
|
|
626
|
+
FINTEGER di = d_in, ni = static_cast<FINTEGER>(n);
|
|
601
627
|
float one = 1.0;
|
|
602
628
|
ssyrk_("Up",
|
|
603
629
|
"Non transposed",
|
|
@@ -614,38 +640,44 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
614
640
|
float* ci = cov;
|
|
615
641
|
printf("cov=\n");
|
|
616
642
|
for (int i = 0; i < d_in; i++) {
|
|
617
|
-
for (int j = 0; j < d_in; j++)
|
|
643
|
+
for (int j = 0; j < d_in; j++) {
|
|
618
644
|
printf("%10g ", *ci++);
|
|
645
|
+
}
|
|
619
646
|
printf("\n");
|
|
620
647
|
}
|
|
621
648
|
}
|
|
622
649
|
|
|
623
650
|
std::vector<double> covd(d_in * d_in);
|
|
624
|
-
for (size_t i = 0; i < d_in * d_in; i++)
|
|
651
|
+
for (size_t i = 0; i < d_in * d_in; i++) {
|
|
625
652
|
covd[i] = cov[i];
|
|
653
|
+
}
|
|
626
654
|
|
|
627
655
|
std::vector<double> eigenvaluesd(d_in);
|
|
628
656
|
|
|
629
657
|
eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
|
|
630
658
|
|
|
631
|
-
for (size_t i = 0; i < d_in * d_in; i++)
|
|
659
|
+
for (size_t i = 0; i < d_in * d_in; i++) {
|
|
632
660
|
PCAMat[i] = covd[i];
|
|
661
|
+
}
|
|
633
662
|
eigenvalues.resize(d_in);
|
|
634
663
|
|
|
635
|
-
for (
|
|
664
|
+
for (int i = 0; i < d_in; i++) {
|
|
636
665
|
eigenvalues[i] = eigenvaluesd[i];
|
|
666
|
+
}
|
|
637
667
|
|
|
638
668
|
} else {
|
|
639
669
|
std::vector<float> xc(n * d_in);
|
|
640
670
|
|
|
641
|
-
for (
|
|
642
|
-
for (
|
|
671
|
+
for (idx_t i = 0; i < n; i++) {
|
|
672
|
+
for (int j = 0; j < d_in; j++) {
|
|
643
673
|
xc[i * d_in + j] = x[i * d_in + j] - mean[j];
|
|
674
|
+
}
|
|
675
|
+
}
|
|
644
676
|
|
|
645
677
|
// compute Gram matrix
|
|
646
678
|
std::vector<float> gram(n * n);
|
|
647
679
|
{
|
|
648
|
-
FINTEGER di = d_in, ni = n;
|
|
680
|
+
FINTEGER di = d_in, ni = static_cast<FINTEGER>(n);
|
|
649
681
|
float one = 1.0, zero = 0.0;
|
|
650
682
|
ssyrk_("Up",
|
|
651
683
|
"Transposed",
|
|
@@ -662,16 +694,18 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
662
694
|
if (verbose && d_in <= 10) {
|
|
663
695
|
float* ci = gram.data();
|
|
664
696
|
printf("gram=\n");
|
|
665
|
-
for (
|
|
666
|
-
for (
|
|
697
|
+
for (idx_t i = 0; i < n; i++) {
|
|
698
|
+
for (idx_t j = 0; j < n; j++) {
|
|
667
699
|
printf("%10g ", *ci++);
|
|
700
|
+
}
|
|
668
701
|
printf("\n");
|
|
669
702
|
}
|
|
670
703
|
}
|
|
671
704
|
|
|
672
705
|
std::vector<double> gramd(n * n);
|
|
673
|
-
for (size_t i = 0; i < n * n; i++)
|
|
706
|
+
for (size_t i = 0; i < n * n; i++) {
|
|
674
707
|
gramd[i] = gram[i];
|
|
708
|
+
}
|
|
675
709
|
|
|
676
710
|
std::vector<double> eigenvaluesd(n);
|
|
677
711
|
|
|
@@ -681,17 +715,19 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
681
715
|
|
|
682
716
|
PCAMat.resize(d_in * n);
|
|
683
717
|
|
|
684
|
-
for (size_t i = 0; i < n * n; i++)
|
|
718
|
+
for (size_t i = 0; i < n * n; i++) {
|
|
685
719
|
gram[i] = gramd[i];
|
|
720
|
+
}
|
|
686
721
|
|
|
687
722
|
eigenvalues.resize(d_in);
|
|
688
723
|
// fill in only the n first ones
|
|
689
|
-
for (
|
|
724
|
+
for (idx_t i = 0; i < n; i++) {
|
|
690
725
|
eigenvalues[i] = eigenvaluesd[i];
|
|
726
|
+
}
|
|
691
727
|
|
|
692
728
|
{ // compute PCAMat = x' * v
|
|
693
|
-
FINTEGER di = d_in, ni = n;
|
|
694
|
-
float one = 1.0;
|
|
729
|
+
FINTEGER di = d_in, ni = static_cast<FINTEGER>(n);
|
|
730
|
+
float one = 1.0, zero = 0.0;
|
|
695
731
|
|
|
696
732
|
sgemm_("Non",
|
|
697
733
|
"Non Trans",
|
|
@@ -703,7 +739,7 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
703
739
|
&di,
|
|
704
740
|
gram.data(),
|
|
705
741
|
&ni,
|
|
706
|
-
&
|
|
742
|
+
&zero,
|
|
707
743
|
PCAMat.data(),
|
|
708
744
|
&di);
|
|
709
745
|
}
|
|
@@ -711,9 +747,10 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
711
747
|
if (verbose && d_in <= 10) {
|
|
712
748
|
float* ci = PCAMat.data();
|
|
713
749
|
printf("PCAMat=\n");
|
|
714
|
-
for (
|
|
715
|
-
for (int j = 0; j < d_in; j++)
|
|
750
|
+
for (idx_t i = 0; i < n; i++) {
|
|
751
|
+
for (int j = 0; j < d_in; j++) {
|
|
716
752
|
printf("%10g ", *ci++);
|
|
753
|
+
}
|
|
717
754
|
printf("\n");
|
|
718
755
|
}
|
|
719
756
|
}
|
|
@@ -725,7 +762,9 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
|
|
|
725
762
|
}
|
|
726
763
|
|
|
727
764
|
void PCAMatrix::copy_from(const PCAMatrix& other) {
|
|
728
|
-
|
|
765
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
766
|
+
other.is_trained,
|
|
767
|
+
"source PCAMatrix must be trained before copying");
|
|
729
768
|
mean = other.mean;
|
|
730
769
|
eigenvalues = other.eigenvalues;
|
|
731
770
|
PCAMat = other.PCAMat;
|
|
@@ -735,7 +774,7 @@ void PCAMatrix::copy_from(const PCAMatrix& other) {
|
|
|
735
774
|
|
|
736
775
|
void PCAMatrix::prepare_Ab() {
|
|
737
776
|
FAISS_THROW_IF_NOT_FMT(
|
|
738
|
-
d_out * d_in <= PCAMat.size(),
|
|
777
|
+
static_cast<size_t>(d_out) * d_in <= PCAMat.size(),
|
|
739
778
|
"PCA matrix cannot output %d dimensions from %d ",
|
|
740
779
|
d_out,
|
|
741
780
|
d_in);
|
|
@@ -748,14 +787,17 @@ void PCAMatrix::prepare_Ab() {
|
|
|
748
787
|
if (eigen_power != 0) {
|
|
749
788
|
float* ai = A.data();
|
|
750
789
|
for (int i = 0; i < d_out; i++) {
|
|
751
|
-
float factor = pow(eigenvalues[i] + epsilon, eigen_power);
|
|
752
|
-
for (int j = 0; j < d_in; j++)
|
|
790
|
+
float factor = std::pow(eigenvalues[i] + epsilon, eigen_power);
|
|
791
|
+
for (int j = 0; j < d_in; j++) {
|
|
753
792
|
*ai++ *= factor;
|
|
793
|
+
}
|
|
754
794
|
}
|
|
755
795
|
}
|
|
756
796
|
|
|
757
797
|
if (balanced_bins != 0) {
|
|
758
|
-
|
|
798
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
799
|
+
d_out % balanced_bins == 0,
|
|
800
|
+
"output dimension must be divisible by balanced_bins");
|
|
759
801
|
int dsub = d_out / balanced_bins;
|
|
760
802
|
std::vector<float> Ain;
|
|
761
803
|
std::swap(A, Ain);
|
|
@@ -783,8 +825,9 @@ void PCAMatrix::prepare_Ab() {
|
|
|
783
825
|
|
|
784
826
|
if (verbose) {
|
|
785
827
|
printf(" bin accu=[");
|
|
786
|
-
for (int i = 0; i < balanced_bins; i++)
|
|
828
|
+
for (int i = 0; i < balanced_bins; i++) {
|
|
787
829
|
printf("%g ", accu[i]);
|
|
830
|
+
}
|
|
788
831
|
printf("]\n");
|
|
789
832
|
}
|
|
790
833
|
}
|
|
@@ -802,8 +845,9 @@ void PCAMatrix::prepare_Ab() {
|
|
|
802
845
|
if (eigen_power != 0) {
|
|
803
846
|
for (int i = 0; i < d_out; i++) {
|
|
804
847
|
float factor = pow(eigenvalues[i], eigen_power);
|
|
805
|
-
for (int j = 0; j < d_out; j++)
|
|
848
|
+
for (int j = 0; j < d_out; j++) {
|
|
806
849
|
rr.A[j * d_out + i] *= factor;
|
|
850
|
+
}
|
|
807
851
|
}
|
|
808
852
|
}
|
|
809
853
|
|
|
@@ -833,8 +877,9 @@ void PCAMatrix::prepare_Ab() {
|
|
|
833
877
|
|
|
834
878
|
for (int i = 0; i < d_out; i++) {
|
|
835
879
|
float accu = 0;
|
|
836
|
-
for (int j = 0; j < d_in; j++)
|
|
880
|
+
for (int j = 0; j < d_in; j++) {
|
|
837
881
|
accu -= mean[j] * A[j + i * d_in];
|
|
882
|
+
}
|
|
838
883
|
b[i] = accu;
|
|
839
884
|
}
|
|
840
885
|
|
|
@@ -858,7 +903,7 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
858
903
|
init_rotation.data(),
|
|
859
904
|
d * d * sizeof(rotation[0]));
|
|
860
905
|
} else {
|
|
861
|
-
RandomRotationMatrix rrot(d, d);
|
|
906
|
+
RandomRotationMatrix rrot(static_cast<int>(d), static_cast<int>(d));
|
|
862
907
|
rrot.init(seed);
|
|
863
908
|
for (size_t i = 0; i < d * d; i++) {
|
|
864
909
|
rotation[i] = rrot.A[i];
|
|
@@ -875,9 +920,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
875
920
|
std::vector<double> u(d * d), vt(d * d), singvals(d);
|
|
876
921
|
|
|
877
922
|
for (int i = 0; i < max_iter; i++) {
|
|
878
|
-
print_if_verbose(
|
|
923
|
+
print_if_verbose(
|
|
924
|
+
"rotation", rotation, static_cast<int>(d), static_cast<int>(d));
|
|
879
925
|
{ // rotated_data = np.dot(training_data, rotation)
|
|
880
|
-
FINTEGER di = d,
|
|
926
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
927
|
+
ni = static_cast<FINTEGER>(n);
|
|
881
928
|
double one = 1, zero = 0;
|
|
882
929
|
dgemm_("N",
|
|
883
930
|
"N",
|
|
@@ -893,14 +940,19 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
893
940
|
rotated_x.data(),
|
|
894
941
|
&di);
|
|
895
942
|
}
|
|
896
|
-
print_if_verbose(
|
|
943
|
+
print_if_verbose(
|
|
944
|
+
"rotated_x",
|
|
945
|
+
rotated_x,
|
|
946
|
+
static_cast<int>(n),
|
|
947
|
+
static_cast<int>(d));
|
|
897
948
|
// binarize
|
|
898
949
|
for (size_t j = 0; j < n * d; j++) {
|
|
899
950
|
rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
|
|
900
951
|
}
|
|
901
952
|
// covariance matrix
|
|
902
953
|
{ // rotated_data = np.dot(training_data, rotation)
|
|
903
|
-
FINTEGER di = d,
|
|
954
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
955
|
+
ni = static_cast<FINTEGER>(n);
|
|
904
956
|
double one = 1, zero = 0;
|
|
905
957
|
dgemm_("N",
|
|
906
958
|
"T",
|
|
@@ -916,10 +968,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
916
968
|
cov_mat.data(),
|
|
917
969
|
&di);
|
|
918
970
|
}
|
|
919
|
-
print_if_verbose(
|
|
971
|
+
print_if_verbose(
|
|
972
|
+
"cov_mat", cov_mat, static_cast<int>(d), static_cast<int>(d));
|
|
920
973
|
// SVD
|
|
921
974
|
{
|
|
922
|
-
FINTEGER di = d;
|
|
975
|
+
FINTEGER di = static_cast<FINTEGER>(d);
|
|
923
976
|
FINTEGER lwork = -1, info;
|
|
924
977
|
double lwork1;
|
|
925
978
|
|
|
@@ -939,8 +992,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
939
992
|
&lwork,
|
|
940
993
|
&info);
|
|
941
994
|
|
|
942
|
-
|
|
943
|
-
|
|
995
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
996
|
+
info == 0,
|
|
997
|
+
"LAPACK dgesvd workspace query returned info=%d",
|
|
998
|
+
int(info));
|
|
999
|
+
lwork = static_cast<FINTEGER>(lwork1);
|
|
944
1000
|
std::vector<double> work(lwork);
|
|
945
1001
|
dgesvd_("A",
|
|
946
1002
|
"A",
|
|
@@ -958,11 +1014,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
958
1014
|
&info);
|
|
959
1015
|
FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
|
|
960
1016
|
}
|
|
961
|
-
print_if_verbose("u", u, d, d);
|
|
962
|
-
print_if_verbose("vt", vt, d, d);
|
|
1017
|
+
print_if_verbose("u", u, static_cast<int>(d), static_cast<int>(d));
|
|
1018
|
+
print_if_verbose("vt", vt, static_cast<int>(d), static_cast<int>(d));
|
|
963
1019
|
// update rotation
|
|
964
1020
|
{
|
|
965
|
-
FINTEGER di = d;
|
|
1021
|
+
FINTEGER di = static_cast<FINTEGER>(d);
|
|
966
1022
|
double one = 1, zero = 0;
|
|
967
1023
|
dgemm_("N",
|
|
968
1024
|
"T",
|
|
@@ -978,7 +1034,11 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
978
1034
|
rotation.data(),
|
|
979
1035
|
&di);
|
|
980
1036
|
}
|
|
981
|
-
print_if_verbose(
|
|
1037
|
+
print_if_verbose(
|
|
1038
|
+
"final rot",
|
|
1039
|
+
rotation,
|
|
1040
|
+
static_cast<int>(d),
|
|
1041
|
+
static_cast<int>(d));
|
|
982
1042
|
}
|
|
983
1043
|
A.resize(d * d);
|
|
984
1044
|
for (size_t i = 0; i < d; i++) {
|
|
@@ -989,20 +1049,23 @@ void ITQMatrix::train(idx_t n, const float* xf) {
|
|
|
989
1049
|
is_trained = true;
|
|
990
1050
|
}
|
|
991
1051
|
|
|
992
|
-
ITQTransform::ITQTransform(int
|
|
993
|
-
: VectorTransform(
|
|
994
|
-
do_pca(
|
|
995
|
-
itq(
|
|
996
|
-
pca_then_itq(
|
|
997
|
-
if (!
|
|
998
|
-
|
|
1052
|
+
ITQTransform::ITQTransform(int din, int dout, bool do_pca_in)
|
|
1053
|
+
: VectorTransform(din, dout),
|
|
1054
|
+
do_pca(do_pca_in),
|
|
1055
|
+
itq(dout),
|
|
1056
|
+
pca_then_itq(din, dout, false) {
|
|
1057
|
+
if (!do_pca_in) {
|
|
1058
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1059
|
+
din == dout,
|
|
1060
|
+
"input and output dimensions must match when PCA is disabled");
|
|
999
1061
|
}
|
|
1000
1062
|
max_train_per_dim = 10;
|
|
1001
1063
|
is_trained = false;
|
|
1002
1064
|
}
|
|
1003
1065
|
|
|
1004
1066
|
void ITQTransform::train(idx_t n, const float* x_in) {
|
|
1005
|
-
|
|
1067
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1068
|
+
!is_trained, "ITQTransform has already been trained");
|
|
1006
1069
|
|
|
1007
1070
|
size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
|
|
1008
1071
|
const float* x =
|
|
@@ -1094,17 +1157,18 @@ void ITQTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
|
|
|
1094
1157
|
void ITQTransform::check_identical(const VectorTransform& other_in) const {
|
|
1095
1158
|
VectorTransform::check_identical(other_in);
|
|
1096
1159
|
auto other = dynamic_cast<const ITQTransform*>(&other_in);
|
|
1097
|
-
|
|
1160
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to ITQTransform");
|
|
1098
1161
|
pca_then_itq.check_identical(other->pca_then_itq);
|
|
1099
|
-
|
|
1162
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1163
|
+
other->mean == mean, "ITQTransform mean vectors must match");
|
|
1100
1164
|
}
|
|
1101
1165
|
|
|
1102
1166
|
/*********************************************
|
|
1103
1167
|
* OPQMatrix
|
|
1104
1168
|
*********************************************/
|
|
1105
1169
|
|
|
1106
|
-
OPQMatrix::OPQMatrix(int d, int
|
|
1107
|
-
: LinearTransform(d, d2 == -1 ? d : d2, false), M(
|
|
1170
|
+
OPQMatrix::OPQMatrix(int d, int M_in, int d2)
|
|
1171
|
+
: LinearTransform(d, d2 == -1 ? d : d2, false), M(M_in) {
|
|
1108
1172
|
is_trained = false;
|
|
1109
1173
|
// OPQ is quite expensive to train, so set this right.
|
|
1110
1174
|
max_train_points = 256 * 256;
|
|
@@ -1150,17 +1214,20 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1150
1214
|
{
|
|
1151
1215
|
std::vector<float> sum(d);
|
|
1152
1216
|
const float* xi = x;
|
|
1153
|
-
for (
|
|
1154
|
-
for (int j = 0; j < d_in; j++)
|
|
1217
|
+
for (idx_t i = 0; i < n; i++) {
|
|
1218
|
+
for (int j = 0; j < d_in; j++) {
|
|
1155
1219
|
sum[j] += *xi++;
|
|
1220
|
+
}
|
|
1156
1221
|
}
|
|
1157
|
-
for (
|
|
1222
|
+
for (size_t i = 0; i < d; i++) {
|
|
1158
1223
|
sum[i] /= n;
|
|
1224
|
+
}
|
|
1159
1225
|
float* yi = xtrain.data();
|
|
1160
1226
|
xi = x;
|
|
1161
|
-
for (
|
|
1162
|
-
for (int j = 0; j < d_in; j++)
|
|
1227
|
+
for (idx_t i = 0; i < n; i++) {
|
|
1228
|
+
for (int j = 0; j < d_in; j++) {
|
|
1163
1229
|
*yi++ = *xi++ - sum[j];
|
|
1230
|
+
}
|
|
1164
1231
|
yi += d - d_in;
|
|
1165
1232
|
}
|
|
1166
1233
|
}
|
|
@@ -1169,16 +1236,18 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1169
1236
|
if (A.size() == 0) {
|
|
1170
1237
|
A.resize(d * d);
|
|
1171
1238
|
rotation = A.data();
|
|
1172
|
-
if (verbose)
|
|
1239
|
+
if (verbose) {
|
|
1173
1240
|
printf(" OPQMatrix::train: making random %zd*%zd rotation\n",
|
|
1174
1241
|
d,
|
|
1175
1242
|
d);
|
|
1243
|
+
}
|
|
1176
1244
|
float_randn(rotation, d * d, 1234);
|
|
1177
1245
|
matrix_qr(d, d, rotation);
|
|
1178
1246
|
// we use only the d * d2 upper part of the matrix
|
|
1179
1247
|
A.resize(d * d2);
|
|
1180
1248
|
} else {
|
|
1181
|
-
|
|
1249
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1250
|
+
A.size() == d * d2, "rotation matrix A has incorrect size");
|
|
1182
1251
|
rotation = A.data();
|
|
1183
1252
|
}
|
|
1184
1253
|
|
|
@@ -1192,7 +1261,9 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1192
1261
|
double t0 = getmillisecs();
|
|
1193
1262
|
for (int iter = 0; iter < niter; iter++) {
|
|
1194
1263
|
{ // torch.mm(xtrain, rotation:t())
|
|
1195
|
-
FINTEGER di = d,
|
|
1264
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
1265
|
+
d2i = static_cast<FINTEGER>(d2),
|
|
1266
|
+
ni = static_cast<FINTEGER>(n);
|
|
1196
1267
|
float zero = 0, one = 1;
|
|
1197
1268
|
sgemm_("Transposed",
|
|
1198
1269
|
"Not transposed",
|
|
@@ -1227,18 +1298,21 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1227
1298
|
|
|
1228
1299
|
float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;
|
|
1229
1300
|
|
|
1230
|
-
if (verbose)
|
|
1301
|
+
if (verbose) {
|
|
1231
1302
|
printf(" Iteration %d (%d PQ iterations):"
|
|
1232
1303
|
"%.3f s, obj=%g\n",
|
|
1233
1304
|
iter,
|
|
1234
1305
|
pq_regular.cp.niter,
|
|
1235
1306
|
(getmillisecs() - t0) / 1000.0,
|
|
1236
1307
|
pq_err);
|
|
1308
|
+
}
|
|
1237
1309
|
|
|
1238
1310
|
{
|
|
1239
1311
|
float *u = tmp.data(), *vt = &tmp[d * d];
|
|
1240
1312
|
float* sing_val = &tmp[2 * d * d];
|
|
1241
|
-
FINTEGER di = d,
|
|
1313
|
+
FINTEGER di = static_cast<FINTEGER>(d),
|
|
1314
|
+
d2i = static_cast<FINTEGER>(d2),
|
|
1315
|
+
ni = static_cast<FINTEGER>(n);
|
|
1242
1316
|
float one = 1, zero = 0;
|
|
1243
1317
|
|
|
1244
1318
|
if (verbose) {
|
|
@@ -1277,7 +1351,11 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1277
1351
|
&lwork,
|
|
1278
1352
|
&info);
|
|
1279
1353
|
|
|
1280
|
-
|
|
1354
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
1355
|
+
info == 0,
|
|
1356
|
+
"LAPACK sgesvd workspace query returned info=%d",
|
|
1357
|
+
int(info));
|
|
1358
|
+
lwork = static_cast<FINTEGER>(worksz);
|
|
1281
1359
|
std::vector<float> work(lwork);
|
|
1282
1360
|
// u and vt swapped
|
|
1283
1361
|
sgesvd_("All",
|
|
@@ -1313,9 +1391,10 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1313
1391
|
}
|
|
1314
1392
|
|
|
1315
1393
|
// revert A matrix
|
|
1316
|
-
if (d > d_in) {
|
|
1317
|
-
for (long i = 0; i < d_out; i++)
|
|
1394
|
+
if (d > static_cast<size_t>(d_in)) {
|
|
1395
|
+
for (long i = 0; i < d_out; i++) {
|
|
1318
1396
|
memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
|
|
1397
|
+
}
|
|
1319
1398
|
A.resize(d_in * d_out);
|
|
1320
1399
|
}
|
|
1321
1400
|
|
|
@@ -1327,8 +1406,8 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
|
|
|
1327
1406
|
* NormalizationTransform
|
|
1328
1407
|
*********************************************/
|
|
1329
1408
|
|
|
1330
|
-
NormalizationTransform::NormalizationTransform(int d, float
|
|
1331
|
-
: VectorTransform(d, d), norm(
|
|
1409
|
+
NormalizationTransform::NormalizationTransform(int d, float norm_in)
|
|
1410
|
+
: VectorTransform(d, d), norm(norm_in) {}
|
|
1332
1411
|
|
|
1333
1412
|
NormalizationTransform::NormalizationTransform()
|
|
1334
1413
|
: VectorTransform(-1, -1), norm(-1) {}
|
|
@@ -1354,8 +1433,9 @@ void NormalizationTransform::check_identical(
|
|
|
1354
1433
|
const VectorTransform& other_in) const {
|
|
1355
1434
|
VectorTransform::check_identical(other_in);
|
|
1356
1435
|
auto other = dynamic_cast<const NormalizationTransform*>(&other_in);
|
|
1357
|
-
|
|
1358
|
-
|
|
1436
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to NormalizationTransform");
|
|
1437
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1438
|
+
other->norm == norm, "normalization type must match");
|
|
1359
1439
|
}
|
|
1360
1440
|
|
|
1361
1441
|
/*********************************************
|
|
@@ -1370,12 +1450,12 @@ void CenteringTransform::train(idx_t n, const float* x) {
|
|
|
1370
1450
|
FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
|
|
1371
1451
|
mean.resize(d_in, 0);
|
|
1372
1452
|
for (idx_t i = 0; i < n; i++) {
|
|
1373
|
-
for (
|
|
1453
|
+
for (int j = 0; j < d_in; j++) {
|
|
1374
1454
|
mean[j] += *x++;
|
|
1375
1455
|
}
|
|
1376
1456
|
}
|
|
1377
1457
|
|
|
1378
|
-
for (
|
|
1458
|
+
for (int j = 0; j < d_in; j++) {
|
|
1379
1459
|
mean[j] /= n;
|
|
1380
1460
|
}
|
|
1381
1461
|
is_trained = true;
|
|
@@ -1383,10 +1463,11 @@ void CenteringTransform::train(idx_t n, const float* x) {
|
|
|
1383
1463
|
|
|
1384
1464
|
void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
1385
1465
|
const {
|
|
1386
|
-
|
|
1466
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1467
|
+
is_trained, "CenteringTransform has not been trained");
|
|
1387
1468
|
|
|
1388
1469
|
for (idx_t i = 0; i < n; i++) {
|
|
1389
|
-
for (
|
|
1470
|
+
for (int j = 0; j < d_in; j++) {
|
|
1390
1471
|
*xt++ = *x++ - mean[j];
|
|
1391
1472
|
}
|
|
1392
1473
|
}
|
|
@@ -1394,10 +1475,11 @@ void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
|
1394
1475
|
|
|
1395
1476
|
void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
|
|
1396
1477
|
const {
|
|
1397
|
-
|
|
1478
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1479
|
+
is_trained, "CenteringTransform has not been trained");
|
|
1398
1480
|
|
|
1399
1481
|
for (idx_t i = 0; i < n; i++) {
|
|
1400
|
-
for (
|
|
1482
|
+
for (int j = 0; j < d_in; j++) {
|
|
1401
1483
|
*x++ = *xt++ + mean[j];
|
|
1402
1484
|
}
|
|
1403
1485
|
}
|
|
@@ -1407,8 +1489,9 @@ void CenteringTransform::check_identical(
|
|
|
1407
1489
|
const VectorTransform& other_in) const {
|
|
1408
1490
|
VectorTransform::check_identical(other_in);
|
|
1409
1491
|
auto other = dynamic_cast<const CenteringTransform*>(&other_in);
|
|
1410
|
-
|
|
1411
|
-
|
|
1492
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to CenteringTransform");
|
|
1493
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1494
|
+
other->mean == mean, "CenteringTransform mean vectors must match");
|
|
1412
1495
|
}
|
|
1413
1496
|
|
|
1414
1497
|
/*********************************************
|
|
@@ -1416,37 +1499,40 @@ void CenteringTransform::check_identical(
|
|
|
1416
1499
|
*********************************************/
|
|
1417
1500
|
|
|
1418
1501
|
RemapDimensionsTransform::RemapDimensionsTransform(
|
|
1419
|
-
int
|
|
1420
|
-
int
|
|
1502
|
+
int din,
|
|
1503
|
+
int dout,
|
|
1421
1504
|
const int* map_in)
|
|
1422
|
-
: VectorTransform(
|
|
1423
|
-
map.resize(
|
|
1424
|
-
for (int i = 0; i <
|
|
1505
|
+
: VectorTransform(din, dout) {
|
|
1506
|
+
map.resize(dout);
|
|
1507
|
+
for (int i = 0; i < dout; i++) {
|
|
1425
1508
|
map[i] = map_in[i];
|
|
1426
|
-
|
|
1509
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1510
|
+
map[i] == -1 || (map[i] >= 0 && map[i] < din),
|
|
1511
|
+
"map entries must be -1 (unused) or valid input dimension indices");
|
|
1427
1512
|
}
|
|
1428
1513
|
}
|
|
1429
1514
|
|
|
1430
1515
|
RemapDimensionsTransform::RemapDimensionsTransform(
|
|
1431
|
-
int
|
|
1432
|
-
int
|
|
1516
|
+
int din,
|
|
1517
|
+
int dout,
|
|
1433
1518
|
bool uniform)
|
|
1434
|
-
: VectorTransform(
|
|
1435
|
-
map.resize(
|
|
1519
|
+
: VectorTransform(din, dout) {
|
|
1520
|
+
map.resize(dout, -1);
|
|
1436
1521
|
|
|
1437
1522
|
if (uniform) {
|
|
1438
|
-
if (
|
|
1439
|
-
for (int i = 0; i <
|
|
1440
|
-
map[i *
|
|
1523
|
+
if (din < dout) {
|
|
1524
|
+
for (int i = 0; i < din; i++) {
|
|
1525
|
+
map[i * dout / din] = i;
|
|
1441
1526
|
}
|
|
1442
1527
|
} else {
|
|
1443
|
-
for (int i = 0; i <
|
|
1444
|
-
map[i] = i *
|
|
1528
|
+
for (int i = 0; i < dout; i++) {
|
|
1529
|
+
map[i] = i * din / dout;
|
|
1445
1530
|
}
|
|
1446
1531
|
}
|
|
1447
1532
|
} else {
|
|
1448
|
-
for (int i = 0; i <
|
|
1533
|
+
for (int i = 0; i < din && i < dout; i++) {
|
|
1449
1534
|
map[i] = i;
|
|
1535
|
+
}
|
|
1450
1536
|
}
|
|
1451
1537
|
}
|
|
1452
1538
|
|
|
@@ -1468,8 +1554,9 @@ void RemapDimensionsTransform::reverse_transform(
|
|
|
1468
1554
|
memset(x, 0, sizeof(*x) * n * d_in);
|
|
1469
1555
|
for (idx_t i = 0; i < n; i++) {
|
|
1470
1556
|
for (int j = 0; j < d_out; j++) {
|
|
1471
|
-
if (map[j] >= 0)
|
|
1557
|
+
if (map[j] >= 0) {
|
|
1472
1558
|
x[map[j]] = xt[j];
|
|
1559
|
+
}
|
|
1473
1560
|
}
|
|
1474
1561
|
x += d_in;
|
|
1475
1562
|
xt += d_out;
|
|
@@ -1480,6 +1567,7 @@ void RemapDimensionsTransform::check_identical(
|
|
|
1480
1567
|
const VectorTransform& other_in) const {
|
|
1481
1568
|
VectorTransform::check_identical(other_in);
|
|
1482
1569
|
auto other = dynamic_cast<const RemapDimensionsTransform*>(&other_in);
|
|
1483
|
-
|
|
1484
|
-
|
|
1570
|
+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to RemapDimensionsTransform");
|
|
1571
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
1572
|
+
other->map == map, "RemapDimensionsTransform maps must match");
|
|
1485
1573
|
}
|