faiss 0.5.3 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (379) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +12 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +4 -4
  5. data/ext/faiss/index.cpp +63 -45
  6. data/ext/faiss/index_binary.cpp +37 -27
  7. data/ext/faiss/kmeans.cpp +9 -8
  8. data/ext/faiss/pca_matrix.cpp +9 -7
  9. data/ext/faiss/product_quantizer.cpp +13 -11
  10. data/ext/faiss/utils.cpp +4 -2
  11. data/ext/faiss/utils.h +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +214 -82
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +97 -249
  17. data/vendor/faiss/faiss/Clustering.h +18 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +67 -44
  19. data/vendor/faiss/faiss/Index.cpp +25 -12
  20. data/vendor/faiss/faiss/Index.h +26 -4
  21. data/vendor/faiss/faiss/Index2Layer.cpp +37 -53
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +68 -61
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +36 -34
  24. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +4 -1
  25. data/vendor/faiss/faiss/IndexBinary.cpp +6 -3
  26. data/vendor/faiss/faiss/IndexBinary.h +4 -4
  27. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +1 -1
  28. data/vendor/faiss/faiss/IndexBinaryFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -4
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +92 -95
  31. data/vendor/faiss/faiss/IndexBinaryHNSW.h +9 -3
  32. data/vendor/faiss/faiss/IndexBinaryHash.cpp +45 -236
  33. data/vendor/faiss/faiss/IndexBinaryHash.h +6 -6
  34. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +120 -414
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +105 -129
  36. data/vendor/faiss/faiss/IndexFastScan.h +35 -24
  37. data/vendor/faiss/faiss/IndexFlat.cpp +216 -152
  38. data/vendor/faiss/faiss/IndexFlat.h +32 -14
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +88 -41
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +299 -187
  42. data/vendor/faiss/faiss/IndexHNSW.h +30 -14
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +26 -22
  44. data/vendor/faiss/faiss/IndexIDMap.h +9 -7
  45. data/vendor/faiss/faiss/IndexIVF.cpp +535 -405
  46. data/vendor/faiss/faiss/IndexIVF.h +47 -16
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +77 -74
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +105 -99
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +6 -3
  50. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +379 -249
  51. data/vendor/faiss/faiss/IndexIVFFastScan.h +65 -60
  52. data/vendor/faiss/faiss/IndexIVFFlat.cpp +41 -124
  53. data/vendor/faiss/faiss/IndexIVFFlat.h +32 -0
  54. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +89 -138
  55. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +3 -1
  56. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +18 -15
  57. data/vendor/faiss/faiss/IndexIVFPQ.cpp +77 -907
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +184 -122
  59. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +3 -0
  60. data/vendor/faiss/faiss/IndexIVFPQR.cpp +23 -18
  61. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +59 -60
  62. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -3
  63. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +564 -416
  64. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +269 -111
  65. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +41 -127
  66. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  67. data/vendor/faiss/faiss/IndexLSH.cpp +44 -25
  68. data/vendor/faiss/faiss/IndexLattice.cpp +41 -36
  69. data/vendor/faiss/faiss/IndexNNDescent.cpp +37 -21
  70. data/vendor/faiss/faiss/IndexNNDescent.h +2 -2
  71. data/vendor/faiss/faiss/IndexNSG.cpp +40 -23
  72. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  73. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +32 -12
  74. data/vendor/faiss/faiss/IndexPQ.cpp +129 -213
  75. data/vendor/faiss/faiss/IndexPQ.h +3 -2
  76. data/vendor/faiss/faiss/IndexPQFastScan.cpp +20 -14
  77. data/vendor/faiss/faiss/IndexPQFastScan.h +3 -0
  78. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -18
  79. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  80. data/vendor/faiss/faiss/IndexRaBitQ.cpp +31 -43
  81. data/vendor/faiss/faiss/IndexRaBitQ.h +4 -3
  82. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +135 -317
  83. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +192 -34
  84. data/vendor/faiss/faiss/IndexRefine.cpp +30 -55
  85. data/vendor/faiss/faiss/IndexRefine.h +4 -4
  86. data/vendor/faiss/faiss/IndexReplicas.cpp +6 -6
  87. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +15 -14
  88. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +1 -1
  89. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +82 -14
  90. data/vendor/faiss/faiss/IndexShards.cpp +13 -13
  91. data/vendor/faiss/faiss/IndexShardsIVF.cpp +21 -15
  92. data/vendor/faiss/faiss/MatrixStats.cpp +5 -4
  93. data/vendor/faiss/faiss/MetaIndexes.cpp +19 -17
  94. data/vendor/faiss/faiss/MetaIndexes.h +1 -1
  95. data/vendor/faiss/faiss/MetricType.h +29 -6
  96. data/vendor/faiss/faiss/SuperKMeans.cpp +656 -0
  97. data/vendor/faiss/faiss/SuperKMeans.h +97 -0
  98. data/vendor/faiss/faiss/VectorTransform.cpp +349 -141
  99. data/vendor/faiss/faiss/VectorTransform.h +39 -16
  100. data/vendor/faiss/faiss/build.cpp +23 -0
  101. data/vendor/faiss/faiss/build.h +15 -0
  102. data/vendor/faiss/faiss/clone_index.cpp +55 -51
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +47 -47
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +11 -0
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +38 -38
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +11 -0
  107. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +6 -1
  108. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +6 -5
  110. data/vendor/faiss/faiss/gpu/GpuResources.h +1 -1
  111. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +9 -9
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +4 -3
  113. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +46 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +56 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +78 -1
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +72 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +23 -0
  118. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +1 -1
  119. data/vendor/faiss/faiss/gpu/utils/CuvsUtils.h +21 -10
  120. data/vendor/faiss/faiss/gpu_metal/GpuIndexFlat.h +22 -0
  121. data/vendor/faiss/faiss/gpu_metal/MetalCloner.h +35 -0
  122. data/vendor/faiss/faiss/gpu_metal/MetalFlatKernels.h +40 -0
  123. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +51 -0
  124. data/vendor/faiss/faiss/gpu_metal/MetalIndexFlat.h +65 -0
  125. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +66 -0
  126. data/vendor/faiss/faiss/gpu_metal/MetalResources.h +79 -0
  127. data/vendor/faiss/faiss/gpu_metal/StandardMetalResources.h +35 -0
  128. data/vendor/faiss/faiss/impl/AdSampling.cpp +103 -0
  129. data/vendor/faiss/faiss/impl/AdSampling.h +35 -0
  130. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +64 -34
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -0
  132. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +10 -9
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +3 -28
  134. data/vendor/faiss/faiss/impl/ClusteringHelpers.cpp +244 -0
  135. data/vendor/faiss/faiss/impl/ClusteringHelpers.h +94 -0
  136. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  137. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  138. data/vendor/faiss/faiss/impl/CodePacker.cpp +7 -3
  139. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  140. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  141. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  142. data/vendor/faiss/faiss/impl/DistanceComputer.h +8 -8
  143. data/vendor/faiss/faiss/impl/FaissAssert.h +64 -3
  144. data/vendor/faiss/faiss/impl/FaissException.h +50 -3
  145. data/vendor/faiss/faiss/impl/HNSW.cpp +117 -351
  146. data/vendor/faiss/faiss/impl/HNSW.h +21 -40
  147. data/vendor/faiss/faiss/impl/IDSelector.cpp +15 -11
  148. data/vendor/faiss/faiss/impl/IDSelector.h +8 -8
  149. data/vendor/faiss/faiss/impl/InvertedListScannerStats.h +26 -0
  150. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +114 -102
  151. data/vendor/faiss/faiss/impl/NNDescent.cpp +63 -26
  152. data/vendor/faiss/faiss/impl/NNDescent.h +6 -2
  153. data/vendor/faiss/faiss/impl/NSG.cpp +44 -26
  154. data/vendor/faiss/faiss/impl/NSG.h +20 -10
  155. data/vendor/faiss/faiss/impl/Panorama.cpp +76 -52
  156. data/vendor/faiss/faiss/impl/Panorama.h +265 -78
  157. data/vendor/faiss/faiss/impl/PdxLayout.cpp +93 -0
  158. data/vendor/faiss/faiss/impl/PdxLayout.h +41 -0
  159. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +62 -37
  160. data/vendor/faiss/faiss/impl/PolysemousTraining.h +3 -3
  161. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +35 -35
  162. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +21 -16
  163. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +99 -80
  164. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  165. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +135 -37
  166. data/vendor/faiss/faiss/impl/RaBitQUtils.h +148 -21
  167. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +298 -301
  168. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  169. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  170. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  171. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +40 -32
  172. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  173. data/vendor/faiss/faiss/impl/ResultHandler.h +218 -113
  174. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +119 -2362
  175. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +27 -3
  176. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +14 -11
  177. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  178. data/vendor/faiss/faiss/impl/VisitedTable.h +76 -0
  179. data/vendor/faiss/faiss/impl/approx_topk/approx_topk.h +276 -0
  180. data/vendor/faiss/faiss/impl/approx_topk/avx2.cpp +68 -0
  181. data/vendor/faiss/faiss/{utils → impl}/approx_topk/generic.h +15 -8
  182. data/vendor/faiss/faiss/impl/approx_topk/neon.cpp +68 -0
  183. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab-inl.h +169 -0
  184. data/vendor/faiss/faiss/impl/approx_topk/rq_beam_search_tab.h +117 -0
  185. data/vendor/faiss/faiss/impl/approx_topk/simdlib256-inl.h +146 -0
  186. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHNSW_impl.h +73 -0
  187. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryHash_impl.h +270 -0
  188. data/vendor/faiss/faiss/impl/binary_hamming/IndexBinaryIVF_impl.h +460 -0
  189. data/vendor/faiss/faiss/impl/binary_hamming/IndexIVFSpectralHash_impl.h +159 -0
  190. data/vendor/faiss/faiss/impl/binary_hamming/IndexPQ_impl.h +92 -0
  191. data/vendor/faiss/faiss/impl/binary_hamming/avx2.cpp +26 -0
  192. data/vendor/faiss/faiss/impl/binary_hamming/avx512.cpp +26 -0
  193. data/vendor/faiss/faiss/impl/binary_hamming/dispatch.h +143 -0
  194. data/vendor/faiss/faiss/impl/binary_hamming/neon.cpp +26 -0
  195. data/vendor/faiss/faiss/impl/binary_hamming/rvv.cpp +26 -0
  196. data/vendor/faiss/faiss/impl/expanded_scanners.h +163 -0
  197. data/vendor/faiss/faiss/impl/{FastScanDistancePostProcessing.h → fast_scan/FastScanDistancePostProcessing.h} +13 -6
  198. data/vendor/faiss/faiss/impl/{LookupTableScaler.h → fast_scan/LookupTableScaler.h} +16 -5
  199. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops.h +237 -0
  200. data/vendor/faiss/faiss/impl/fast_scan/accumulate_loops_512.h +185 -0
  201. data/vendor/faiss/faiss/impl/fast_scan/decompose_qbs.h +229 -0
  202. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +268 -0
  203. data/vendor/faiss/faiss/impl/{pq4_fast_scan.cpp → fast_scan/fast_scan.cpp} +176 -4
  204. data/vendor/faiss/faiss/impl/fast_scan/fast_scan.h +341 -0
  205. data/vendor/faiss/faiss/impl/fast_scan/impl-avx2.cpp +36 -0
  206. data/vendor/faiss/faiss/impl/fast_scan/impl-avx512.cpp +40 -0
  207. data/vendor/faiss/faiss/impl/fast_scan/impl-neon.cpp +120 -0
  208. data/vendor/faiss/faiss/impl/fast_scan/impl-riscv.cpp +104 -0
  209. data/vendor/faiss/faiss/impl/fast_scan/kernels_simd256.h +213 -0
  210. data/vendor/faiss/faiss/impl/{pq4_fast_scan_search_qbs.cpp → fast_scan/kernels_simd512.h} +26 -348
  211. data/vendor/faiss/faiss/impl/fast_scan/rabitq_dispatching.h +90 -0
  212. data/vendor/faiss/faiss/impl/fast_scan/rabitq_result_handler.h +108 -0
  213. data/vendor/faiss/faiss/impl/{simd_result_handlers.h → fast_scan/simd_result_handlers.h} +290 -142
  214. data/vendor/faiss/faiss/impl/hnsw/LockVector.cpp +54 -0
  215. data/vendor/faiss/faiss/impl/hnsw/LockVector.h +64 -0
  216. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +91 -0
  217. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -0
  218. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +104 -0
  219. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +111 -0
  220. data/vendor/faiss/faiss/impl/index_read.cpp +1950 -505
  221. data/vendor/faiss/faiss/impl/index_read_utils.h +1 -2
  222. data/vendor/faiss/faiss/impl/index_write.cpp +112 -21
  223. data/vendor/faiss/faiss/impl/io.cpp +6 -6
  224. data/vendor/faiss/faiss/impl/io_macros.h +33 -16
  225. data/vendor/faiss/faiss/impl/kmeans1d.cpp +10 -10
  226. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +81 -40
  227. data/vendor/faiss/faiss/impl/lattice_Zn.h +6 -6
  228. data/vendor/faiss/faiss/impl/mapped_io.cpp +15 -8
  229. data/vendor/faiss/faiss/impl/platform_macros.h +11 -4
  230. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQScanner_impl.h +549 -0
  231. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.cpp +245 -0
  232. data/vendor/faiss/faiss/impl/pq_code_distance/IVFPQ_QueryTables.h +105 -0
  233. data/vendor/faiss/faiss/impl/pq_code_distance/PQDistanceComputer_impl.h +106 -0
  234. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +21 -0
  235. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +21 -0
  236. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +21 -0
  237. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.h} +43 -220
  238. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.h} +25 -112
  239. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +59 -0
  240. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.h +96 -0
  241. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +256 -0
  242. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -146
  243. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +68 -0
  244. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +320 -483
  245. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  246. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  247. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +137 -0
  248. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +371 -0
  249. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +190 -0
  250. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  251. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +603 -0
  252. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +597 -0
  253. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +388 -0
  254. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +630 -0
  255. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +311 -0
  256. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +387 -0
  257. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +54 -0
  258. data/vendor/faiss/faiss/impl/simd_dispatch.h +173 -0
  259. data/vendor/faiss/faiss/impl/simdlib/simdlib.h +57 -0
  260. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_avx2.h +274 -171
  261. data/vendor/faiss/faiss/impl/simdlib/simdlib_avx512.h +414 -0
  262. data/vendor/faiss/faiss/impl/simdlib/simdlib_dispatch.h +44 -0
  263. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_emulated.h +231 -166
  264. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_neon.h +275 -217
  265. data/vendor/faiss/faiss/{utils → impl/simdlib}/simdlib_ppc64.h +201 -160
  266. data/vendor/faiss/faiss/impl/svs_io.cpp +12 -3
  267. data/vendor/faiss/faiss/impl/svs_io.h +8 -2
  268. data/vendor/faiss/faiss/index_factory.cpp +115 -28
  269. data/vendor/faiss/faiss/index_io.h +53 -3
  270. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +73 -20
  271. data/vendor/faiss/faiss/invlists/DirectMap.cpp +24 -14
  272. data/vendor/faiss/faiss/invlists/DirectMap.h +4 -3
  273. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +157 -73
  274. data/vendor/faiss/faiss/invlists/InvertedLists.h +86 -23
  275. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +4 -4
  276. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +14 -14
  277. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  278. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  279. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +2 -2
  280. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  281. data/vendor/faiss/faiss/svs/IndexSVSIVF.cpp +350 -0
  282. data/vendor/faiss/faiss/svs/IndexSVSIVF.h +128 -0
  283. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.cpp +40 -0
  284. data/vendor/faiss/faiss/svs/IndexSVSIVFLVQ.h +43 -0
  285. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.cpp +225 -0
  286. data/vendor/faiss/faiss/svs/IndexSVSIVFLeanVec.h +71 -0
  287. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +25 -1
  288. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +19 -2
  289. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +1 -1
  290. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +19 -2
  291. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +14 -0
  292. data/vendor/faiss/faiss/utils/Heap.cpp +56 -10
  293. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  294. data/vendor/faiss/faiss/utils/NeuralNet.cpp +54 -40
  295. data/vendor/faiss/faiss/utils/NeuralNet.h +1 -1
  296. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +10 -4
  297. data/vendor/faiss/faiss/utils/distances.cpp +507 -559
  298. data/vendor/faiss/faiss/utils/distances.h +118 -1
  299. data/vendor/faiss/faiss/utils/distances_dispatch.h +250 -0
  300. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +8 -7
  301. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +33 -14
  302. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +12 -1
  303. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +16 -293
  304. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based_neon.cpp +57 -0
  305. data/vendor/faiss/faiss/utils/distances_fused/simdlib_kernel-inl.h +290 -0
  306. data/vendor/faiss/faiss/utils/distances_simd.cpp +72 -3681
  307. data/vendor/faiss/faiss/utils/extra_distances.cpp +60 -102
  308. data/vendor/faiss/faiss/utils/extra_distances.h +79 -7
  309. data/vendor/faiss/faiss/utils/hamming-inl.h +13 -11
  310. data/vendor/faiss/faiss/utils/hamming.cpp +66 -517
  311. data/vendor/faiss/faiss/utils/hamming.h +92 -2
  312. data/vendor/faiss/faiss/utils/hamming_distance/common.h +287 -10
  313. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +15 -0
  314. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512.cpp +15 -0
  315. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx2.h +142 -0
  316. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +234 -0
  317. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-generic.h +368 -0
  318. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-neon.h +322 -0
  319. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-rvv.h +39 -0
  320. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer.h +146 -0
  321. data/vendor/faiss/faiss/utils/hamming_distance/hamming_impl.h +481 -0
  322. data/vendor/faiss/faiss/utils/hamming_distance/hamming_neon.cpp +15 -0
  323. data/vendor/faiss/faiss/utils/hamming_distance/hamming_rvv.cpp +15 -0
  324. data/vendor/faiss/faiss/utils/partitioning.cpp +66 -987
  325. data/vendor/faiss/faiss/utils/partitioning.h +31 -0
  326. data/vendor/faiss/faiss/utils/popcount.h +29 -0
  327. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  328. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  329. data/vendor/faiss/faiss/utils/quantize_lut.cpp +30 -30
  330. data/vendor/faiss/faiss/utils/quantize_lut.h +1 -1
  331. data/vendor/faiss/faiss/utils/rabitq_simd.h +124 -343
  332. data/vendor/faiss/faiss/utils/random.cpp +6 -6
  333. data/vendor/faiss/faiss/utils/simd_impl/IVFFlatScanner-inl.h +51 -0
  334. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +154 -0
  335. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +777 -0
  336. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +306 -0
  337. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1431 -0
  338. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1095 -0
  339. data/vendor/faiss/faiss/utils/simd_impl/distances_rvv.cpp +189 -0
  340. data/vendor/faiss/faiss/utils/simd_impl/distances_simdlib256.h +195 -0
  341. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +392 -0
  342. data/vendor/faiss/faiss/utils/{distances_fused/simdlib_based.h → simd_impl/exhaustive_L2sqr_blas_cmax.h} +5 -10
  343. data/vendor/faiss/faiss/utils/simd_impl/hamming_impl.h +481 -0
  344. data/vendor/faiss/faiss/utils/simd_impl/partitioning_avx2.cpp +14 -0
  345. data/vendor/faiss/faiss/utils/simd_impl/partitioning_neon.cpp +14 -0
  346. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +1085 -0
  347. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx2.cpp +355 -0
  348. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512.cpp +477 -0
  349. data/vendor/faiss/faiss/utils/simd_impl/rabitq_neon.cpp +55 -0
  350. data/vendor/faiss/faiss/utils/simd_impl/rabitq_rvv.cpp +55 -0
  351. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_dispatch.h +32 -0
  352. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels.h +43 -0
  353. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx2.cpp +57 -0
  354. data/vendor/faiss/faiss/utils/simd_impl/super_kmeans_kernels_avx512.cpp +45 -0
  355. data/vendor/faiss/faiss/utils/simd_levels.cpp +334 -0
  356. data/vendor/faiss/faiss/utils/simd_levels.h +183 -0
  357. data/vendor/faiss/faiss/utils/sorting.cpp +48 -36
  358. data/vendor/faiss/faiss/utils/utils.cpp +21 -14
  359. data/vendor/faiss/faiss/utils/utils.h +3 -3
  360. metadata +156 -42
  361. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +0 -29
  362. data/vendor/faiss/faiss/impl/RaBitQStats.h +0 -56
  363. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  364. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  365. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +0 -216
  366. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +0 -224
  367. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +0 -84
  368. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +0 -196
  369. data/vendor/faiss/faiss/utils/approx_topk/mode.h +0 -34
  370. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +0 -36
  371. data/vendor/faiss/faiss/utils/extra_distances-inl.h +0 -228
  372. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +0 -462
  373. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +0 -490
  374. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -450
  375. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +0 -87
  376. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +0 -524
  377. data/vendor/faiss/faiss/utils/simdlib.h +0 -42
  378. data/vendor/faiss/faiss/utils/simdlib_avx512.h +0 -296
  379. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -11,15 +11,14 @@
11
11
  #include <faiss/VectorTransform.h>
12
12
  #include <faiss/impl/AuxIndexStructures.h>
13
13
 
14
- #include <chrono>
15
14
  #include <cinttypes>
16
15
  #include <cmath>
17
16
  #include <cstdio>
18
17
  #include <cstring>
19
-
20
- #include <omp.h>
18
+ #include <limits>
21
19
 
22
20
  #include <faiss/IndexFlat.h>
21
+ #include <faiss/impl/ClusteringHelpers.h>
23
22
  #include <faiss/impl/FaissAssert.h>
24
23
  #include <faiss/impl/kmeans1d.h>
25
24
  #include <faiss/utils/distances.h>
@@ -28,10 +27,10 @@
28
27
 
29
28
  namespace faiss {
30
29
 
31
- Clustering::Clustering(int d, int k) : d(d), k(k) {}
30
+ Clustering::Clustering(int d_, int k_) : d(d_), k(k_) {}
32
31
 
33
- Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
34
- : ClusteringParameters(cp), d(d), k(k) {}
32
+ Clustering::Clustering(int d_, int k_, const ClusteringParameters& cp)
33
+ : ClusteringParameters(cp), d(d_), k(k_) {}
35
34
 
36
35
  void Clustering::post_process_centroids() {
37
36
  if (spherical) {
@@ -58,213 +57,6 @@ void Clustering::train(
58
57
  weights);
59
58
  }
60
59
 
61
- namespace {
62
-
63
- uint64_t get_actual_rng_seed(const int seed) {
64
- return (seed >= 0)
65
- ? seed
66
- : static_cast<uint64_t>(std::chrono::high_resolution_clock::now()
67
- .time_since_epoch()
68
- .count());
69
- }
70
-
71
- idx_t subsample_training_set(
72
- const Clustering& clus,
73
- idx_t nx,
74
- const uint8_t* x,
75
- size_t line_size,
76
- const float* weights,
77
- uint8_t** x_out,
78
- float** weights_out) {
79
- if (clus.verbose) {
80
- printf("Sampling a subset of %zd / %" PRId64 " for training\n",
81
- clus.k * clus.max_points_per_centroid,
82
- nx);
83
- }
84
-
85
- const uint64_t actual_seed = get_actual_rng_seed(clus.seed);
86
-
87
- std::vector<int> perm;
88
- if (clus.use_faster_subsampling) {
89
- // use subsampling with splitmix64 rng
90
- SplitMix64RandomGenerator rng(actual_seed);
91
-
92
- const idx_t new_nx = clus.k * clus.max_points_per_centroid;
93
- perm.resize(new_nx);
94
- for (idx_t i = 0; i < new_nx; i++) {
95
- perm[i] = rng.rand_int(nx);
96
- }
97
- } else {
98
- // use subsampling with a default std rng
99
- perm.resize(nx);
100
- rand_perm(perm.data(), nx, actual_seed);
101
- }
102
-
103
- nx = clus.k * clus.max_points_per_centroid;
104
- uint8_t* x_new = new uint8_t[nx * line_size];
105
- *x_out = x_new;
106
-
107
- // might be worth omp-ing as well
108
- for (idx_t i = 0; i < nx; i++) {
109
- memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
110
- }
111
- if (weights) {
112
- float* weights_new = new float[nx];
113
- for (idx_t i = 0; i < nx; i++) {
114
- weights_new[i] = weights[perm[i]];
115
- }
116
- *weights_out = weights_new;
117
- } else {
118
- *weights_out = nullptr;
119
- }
120
- return nx;
121
- }
122
-
123
- /** compute centroids as (weighted) sum of training points
124
- *
125
- * @param x training vectors, size n * code_size (from codec)
126
- * @param codec how to decode the vectors (if NULL then cast to float*)
127
- * @param weights per-training vector weight, size n (or NULL)
128
- * @param assign nearest centroid for each training vector, size n
129
- * @param k_frozen do not update the k_frozen first centroids
130
- * @param centroids centroid vectors (output only), size k * d
131
- * @param hassign histogram of assignments per centroid (size k),
132
- * should be 0 on input
133
- *
134
- */
135
-
136
- void compute_centroids(
137
- size_t d,
138
- size_t k,
139
- size_t n,
140
- size_t k_frozen,
141
- const uint8_t* x,
142
- const Index* codec,
143
- const int64_t* assign,
144
- const float* weights,
145
- float* hassign,
146
- float* centroids) {
147
- k -= k_frozen;
148
- centroids += k_frozen * d;
149
-
150
- memset(centroids, 0, sizeof(*centroids) * d * k);
151
-
152
- size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
153
-
154
- #pragma omp parallel
155
- {
156
- int nt = omp_get_num_threads();
157
- int rank = omp_get_thread_num();
158
-
159
- // this thread is taking care of centroids c0:c1
160
- size_t c0 = (k * rank) / nt;
161
- size_t c1 = (k * (rank + 1)) / nt;
162
- std::vector<float> decode_buffer(d);
163
-
164
- for (size_t i = 0; i < n; i++) {
165
- int64_t ci = assign[i];
166
- assert(ci >= 0 && ci < k + k_frozen);
167
- ci -= k_frozen;
168
- if (ci >= c0 && ci < c1) {
169
- float* c = centroids + ci * d;
170
- const float* xi;
171
- if (!codec) {
172
- xi = reinterpret_cast<const float*>(x + i * line_size);
173
- } else {
174
- float* xif = decode_buffer.data();
175
- codec->sa_decode(1, x + i * line_size, xif);
176
- xi = xif;
177
- }
178
- if (weights) {
179
- float w = weights[i];
180
- hassign[ci] += w;
181
- for (size_t j = 0; j < d; j++) {
182
- c[j] += xi[j] * w;
183
- }
184
- } else {
185
- hassign[ci] += 1.0;
186
- for (size_t j = 0; j < d; j++) {
187
- c[j] += xi[j];
188
- }
189
- }
190
- }
191
- }
192
- }
193
-
194
- #pragma omp parallel for
195
- for (idx_t ci = 0; ci < k; ci++) {
196
- if (hassign[ci] == 0) {
197
- continue;
198
- }
199
- float norm = 1 / hassign[ci];
200
- float* c = centroids + ci * d;
201
- for (size_t j = 0; j < d; j++) {
202
- c[j] *= norm;
203
- }
204
- }
205
- }
206
-
207
- // a bit above machine epsilon for float16
208
- #define EPS (1 / 1024.)
209
-
210
- /** Handle empty clusters by splitting larger ones.
211
- *
212
- * It works by slightly changing the centroids to make 2 clusters from
213
- * a single one. Takes the same arguments as compute_centroids.
214
- *
215
- * @return nb of splitting operations (larger is worse)
216
- */
217
- int split_clusters(
218
- size_t d,
219
- size_t k,
220
- size_t n,
221
- size_t k_frozen,
222
- float* hassign,
223
- float* centroids) {
224
- k -= k_frozen;
225
- centroids += k_frozen * d;
226
-
227
- /* Take care of void clusters */
228
- size_t nsplit = 0;
229
- RandomGenerator rng(1234);
230
- for (size_t ci = 0; ci < k; ci++) {
231
- if (hassign[ci] == 0) { /* need to redefine a centroid */
232
- size_t cj;
233
- for (cj = 0; true; cj = (cj + 1) % k) {
234
- /* probability to pick this cluster for split */
235
- float p = (hassign[cj] - 1.0) / (float)(n - k);
236
- float r = rng.rand_float();
237
- if (r < p) {
238
- break; /* found our cluster to be split */
239
- }
240
- }
241
- memcpy(centroids + ci * d,
242
- centroids + cj * d,
243
- sizeof(*centroids) * d);
244
-
245
- /* small symmetric perturbation */
246
- for (size_t j = 0; j < d; j++) {
247
- if (j % 2 == 0) {
248
- centroids[ci * d + j] *= 1 + EPS;
249
- centroids[cj * d + j] *= 1 - EPS;
250
- } else {
251
- centroids[ci * d + j] *= 1 - EPS;
252
- centroids[cj * d + j] *= 1 + EPS;
253
- }
254
- }
255
-
256
- /* assume even split of the cluster */
257
- hassign[ci] = hassign[cj] / 2;
258
- hassign[cj] -= hassign[ci];
259
- nsplit++;
260
- }
261
- }
262
-
263
- return nsplit;
264
- }
265
-
266
- } // namespace
267
-
268
60
  void Clustering::train_encoded(
269
61
  idx_t nx,
270
62
  const uint8_t* x_in,
@@ -272,7 +64,7 @@ void Clustering::train_encoded(
272
64
  Index& index,
273
65
  const float* weights) {
274
66
  FAISS_THROW_IF_NOT_FMT(
275
- nx >= k,
67
+ nx >= static_cast<idx_t>(k),
276
68
  "Number of training points (%" PRId64
277
69
  ") should be at least "
278
70
  "as large as number of clusters (%zd)",
@@ -280,13 +72,13 @@ void Clustering::train_encoded(
280
72
  k);
281
73
 
282
74
  FAISS_THROW_IF_NOT_FMT(
283
- (!codec || codec->d == d),
75
+ (!codec || static_cast<size_t>(codec->d) == d),
284
76
  "Codec dimension %d not the same as data dimension %d",
285
77
  int(codec->d),
286
78
  int(d));
287
79
 
288
80
  FAISS_THROW_IF_NOT_FMT(
289
- index.d == d,
81
+ static_cast<size_t>(index.d) == d,
290
82
  "Index dimension %d not the same as data dimension %d",
291
83
  int(index.d),
292
84
  int(d));
@@ -309,16 +101,16 @@ void Clustering::train_encoded(
309
101
  std::unique_ptr<float[]> del3;
310
102
  size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
311
103
 
312
- if (nx > k * max_points_per_centroid) {
104
+ if (static_cast<size_t>(nx) > k * max_points_per_centroid) {
313
105
  uint8_t* x_new;
314
106
  float* weights_new;
315
- nx = subsample_training_set(
107
+ nx = detail::subsample_training_set(
316
108
  *this, nx, x, line_size, weights, &x_new, &weights_new);
317
109
  del1.reset(x_new);
318
110
  x = x_new;
319
111
  del3.reset(weights_new);
320
112
  weights = weights_new;
321
- } else if (nx < k * min_points_per_centroid) {
113
+ } else if (static_cast<size_t>(nx) < k * min_points_per_centroid) {
322
114
  fprintf(stderr,
323
115
  "WARNING clustering %" PRId64
324
116
  " points to %zd centroids: "
@@ -328,7 +120,7 @@ void Clustering::train_encoded(
328
120
  idx_t(k) * min_points_per_centroid);
329
121
  }
330
122
 
331
- if (nx == k) {
123
+ if (static_cast<size_t>(nx) == k) {
332
124
  // this is a corner case, just copy training set to clusters
333
125
  if (verbose) {
334
126
  printf("Number of training points (%" PRId64
@@ -397,7 +189,7 @@ void Clustering::train_encoded(
397
189
  t0 = getmillisecs();
398
190
 
399
191
  // initialize seed
400
- const uint64_t actual_seed = get_actual_rng_seed(seed);
192
+ const uint64_t actual_seed = detail::get_actual_rng_seed(seed);
401
193
 
402
194
  // temporary buffer to decode vectors during the optimization
403
195
  std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
@@ -407,19 +199,52 @@ void Clustering::train_encoded(
407
199
  printf("Outer iteration %d / %d\n", redo, nredo);
408
200
  }
409
201
 
410
- // initialize (remaining) centroids with random points from the dataset
202
+ // initialize centroids using the selected method
411
203
  centroids.resize(d * k);
412
- std::vector<int> perm(nx);
413
204
 
414
- rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
205
+ size_t k_to_init = k - n_input_centroids;
206
+ if (k_to_init > 0) {
207
+ // Fast path for RANDOM initialization - preserves exact original
208
+ // behavior
209
+ if (init_method == ClusteringInitMethod::RANDOM) {
210
+ std::vector<int> perm(nx);
211
+ rand_perm(perm.data(), nx, actual_seed + 1 + redo * 15486557L);
212
+ for (size_t i = 0; i < k_to_init; i++) {
213
+ if (!codec) {
214
+ memcpy(centroids.data() + (n_input_centroids + i) * d,
215
+ x + perm[n_input_centroids + i] * line_size,
216
+ line_size);
217
+ } else {
218
+ codec->sa_decode(
219
+ 1,
220
+ x + perm[n_input_centroids + i] * line_size,
221
+ centroids.data() + (n_input_centroids + i) * d);
222
+ }
223
+ }
224
+ } else {
225
+ // For k-means++ and AFK-MC², we need all vectors decoded
226
+ const float* x_float = nullptr;
227
+ std::vector<float> x_decoded;
415
228
 
416
- if (!codec) {
417
- for (int i = n_input_centroids; i < k; i++) {
418
- memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
419
- }
420
- } else {
421
- for (int i = n_input_centroids; i < k; i++) {
422
- codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
229
+ if (!codec) {
230
+ x_float = reinterpret_cast<const float*>(x);
231
+ } else {
232
+ // Decode all vectors for initialization
233
+ x_decoded.resize(nx * d);
234
+ codec->sa_decode(nx, x, x_decoded.data());
235
+ x_float = x_decoded.data();
236
+ }
237
+
238
+ ClusteringInitialization initializer(d, k_to_init);
239
+ initializer.method = init_method;
240
+ initializer.seed = actual_seed + 1 + redo * 15486557L;
241
+ initializer.afkmc2_chain_length = afkmc2_chain_length;
242
+ initializer.init_centroids(
243
+ nx,
244
+ x_float,
245
+ centroids.data() + n_input_centroids * d,
246
+ n_input_centroids,
247
+ n_input_centroids > 0 ? centroids.data() : nullptr);
423
248
  }
424
249
  }
425
250
 
@@ -453,9 +278,10 @@ void Clustering::train_encoded(
453
278
  } else {
454
279
  // search by blocks of decode_block_size vectors
455
280
  size_t code_size = codec->sa_code_size();
456
- for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
281
+ for (size_t i0 = 0; i0 < static_cast<size_t>(nx);
282
+ i0 += decode_block_size) {
457
283
  size_t i1 = i0 + decode_block_size;
458
- if (i1 > nx) {
284
+ if (i1 > static_cast<size_t>(nx)) {
459
285
  i1 = nx;
460
286
  }
461
287
  codec->sa_decode(
@@ -474,7 +300,7 @@ void Clustering::train_encoded(
474
300
 
475
301
  // accumulate objective
476
302
  obj = 0;
477
- for (int j = 0; j < nx; j++) {
303
+ for (idx_t j = 0; j < nx; j++) {
478
304
  obj += dis[j];
479
305
  }
480
306
 
@@ -482,7 +308,7 @@ void Clustering::train_encoded(
482
308
  std::vector<float> hassign(k);
483
309
 
484
310
  size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
485
- compute_centroids(
311
+ detail::compute_centroids(
486
312
  d,
487
313
  k,
488
314
  nx,
@@ -494,7 +320,7 @@ void Clustering::train_encoded(
494
320
  hassign.data(),
495
321
  centroids.data());
496
322
 
497
- int nsplit = split_clusters(
323
+ int nsplit = detail::split_clusters(
498
324
  d, k, nx, k_frozen, hassign.data(), centroids.data());
499
325
 
500
326
  // collect statistics
@@ -502,7 +328,7 @@ void Clustering::train_encoded(
502
328
  obj,
503
329
  (getmillisecs() - t0) / 1000.0,
504
330
  t_search_tot / 1000,
505
- imbalance_factor(nx, k, assign.get()),
331
+ imbalance_factor(nx, static_cast<int>(k), assign.get()),
506
332
  nsplit};
507
333
  iteration_stats.push_back(stats);
508
334
 
@@ -529,6 +355,27 @@ void Clustering::train_encoded(
529
355
 
530
356
  index.add(k, centroids.data());
531
357
  InterruptCallback::check();
358
+
359
+ // Early stopping: if objective didn't change, we've converged.
360
+ // Safe to access iteration_stats[size - 2] because we push_back
361
+ // above, so size >= i + 1, and when i > 0 we have size >= 2.
362
+ if (i > 0) {
363
+ float prev_obj =
364
+ iteration_stats[iteration_stats.size() - 2].obj;
365
+
366
+ double change = (prev_obj == 0)
367
+ ? std::numeric_limits<double>::max()
368
+ : std::abs(prev_obj - obj) / std::abs(prev_obj);
369
+
370
+ if (change >= 0 && change <= early_stop_threshold) {
371
+ if (verbose) {
372
+ printf("\n Converged at iteration %d: "
373
+ "objective did not change\n",
374
+ i);
375
+ }
376
+ break;
377
+ }
378
+ }
532
379
  }
533
380
 
534
381
  if (verbose) {
@@ -555,19 +402,19 @@ void Clustering::train_encoded(
555
402
  }
556
403
  }
557
404
 
558
- Clustering1D::Clustering1D(int k) : Clustering(1, k) {}
405
+ Clustering1D::Clustering1D(int k_) : Clustering(1, k_) {}
559
406
 
560
- Clustering1D::Clustering1D(int k, const ClusteringParameters& cp)
561
- : Clustering(1, k, cp) {}
407
+ Clustering1D::Clustering1D(int k_, const ClusteringParameters& cp)
408
+ : Clustering(1, k_, cp) {}
562
409
 
563
410
  void Clustering1D::train_exact(idx_t n, const float* x) {
564
411
  const float* xt = x;
565
412
 
566
413
  std::unique_ptr<uint8_t[]> del;
567
- if (n > k * max_points_per_centroid) {
414
+ if (static_cast<size_t>(n) > k * max_points_per_centroid) {
568
415
  uint8_t* x_new;
569
416
  float* weights_new;
570
- n = subsample_training_set(
417
+ n = detail::subsample_training_set(
571
418
  *this,
572
419
  n,
573
420
  (uint8_t*)x,
@@ -592,7 +439,7 @@ float kmeans_clustering(
592
439
  size_t k,
593
440
  const float* x,
594
441
  float* centroids) {
595
- Clustering clus(d, k);
442
+ Clustering clus(static_cast<int>(d), static_cast<int>(k));
596
443
  clus.verbose = d * n * k > (size_t(1) << 30);
597
444
  // display logs if > 1Gflop per iteration
598
445
  IndexFlatL2 index(d);
@@ -615,13 +462,14 @@ Index* ProgressiveDimIndexFactory::operator()(int dim) {
615
462
  return new IndexFlatL2(dim);
616
463
  }
617
464
 
618
- ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
465
+ ProgressiveDimClustering::ProgressiveDimClustering(int d_, int k_)
466
+ : d(d_), k(k_) {}
619
467
 
620
468
  ProgressiveDimClustering::ProgressiveDimClustering(
621
- int d,
622
- int k,
469
+ int d_,
470
+ int k_,
623
471
  const ProgressiveDimClusteringParameters& cp)
624
- : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
472
+ : ProgressiveDimClusteringParameters(cp), d(d_), k(k_) {}
625
473
 
626
474
  namespace {
627
475
 
@@ -642,7 +490,7 @@ void ProgressiveDimClustering::train(
642
490
  ProgressiveDimIndexFactory& factory) {
643
491
  int d_prev = 0;
644
492
 
645
- PCAMatrix pca(d, d);
493
+ PCAMatrix pca(static_cast<int>(d), static_cast<int>(d));
646
494
 
647
495
  std::vector<float> xbuf;
648
496
  if (apply_pca) {
@@ -667,7 +515,7 @@ void ProgressiveDimClustering::train(
667
515
  }
668
516
  std::unique_ptr<Index> clustering_index(factory(di));
669
517
 
670
- Clustering clus(di, k, *this);
518
+ Clustering clus(di, static_cast<int>(k), *this);
671
519
  if (d_prev > 0) {
672
520
  // copy warm-start centroids (padded with 0s)
673
521
  clus.centroids.resize(k * di);
@@ -10,6 +10,7 @@
10
10
  #ifndef FAISS_CLUSTERING_H
11
11
  #define FAISS_CLUSTERING_H
12
12
  #include <faiss/Index.h>
13
+ #include <faiss/impl/ClusteringInitialization.h>
13
14
 
14
15
  #include <vector>
15
16
 
@@ -57,6 +58,23 @@ struct ClusteringParameters {
57
58
  /// Whether to use splitmix64-based random number generator for subsampling,
58
59
  /// which is faster, but may pick duplicate points.
59
60
  bool use_faster_subsampling = false;
61
+
62
+ /// Initialization method for centroids.
63
+ /// RANDOM: uniform random sampling (default, current behavior)
64
+ /// KMEANS_PLUS_PLUS: k-means++ (O(nkd), better quality)
65
+ /// AFK_MC2: Assumption-Free K-MC² (O(nd) + O(mk²d), fast approximation)
66
+ ClusteringInitMethod init_method = ClusteringInitMethod::RANDOM;
67
+
68
+ /// Chain length for AFK-MC² initialization.
69
+ /// Only used when init_method = AFK_MC2.
70
+ /// Longer chains give better approximation but are slower.
71
+ uint16_t afkmc2_chain_length = 50;
72
+
73
+ /// Early stop threshold, the range is [0, 1].
74
+ /// The value of 0 implies a default Faiss behavior,
75
+ /// so the training process stops only if an error
76
+ /// is unchanged from the previous iteration.
77
+ double early_stop_threshold = 0.0;
60
78
  };
61
79
 
62
80
  struct ClusteringIterationStats {