faiss 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (293)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/gpu/GpuIndexIVF.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
@@ -26,6 +26,12 @@ struct GpuIndexIVFConfig : public GpuIndexConfig {
 
     /// Configuration for the coarse quantizer object
     GpuIndexFlatConfig flatConfig;
+
+    /// This flag controls the CPU fallback logic for coarse quantizer
+    /// component of the index. When set to false (default), the cloner will
+    /// throw an exception for indices not implemented on GPU. When set to
+    /// true, it will fallback to a CPU implementation.
+    bool allowCpuCoarseQuantizer = false;
 };
 
 /// Base class of all GPU IVF index types. This (for now) deliberately does not
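The new allowCpuCoarseQuantizer flag in GpuIndexIVFConfig opts in to a CPU coarse quantizer for quantizer types that have no GPU implementation, instead of the cloner throwing. A minimal usage sketch, not taken from the diff, assuming the flag is honored by the GPU IVF copy/clone path as the doc comment describes:

    #include <faiss/gpu/GpuIndexIVFFlat.h>
    #include <faiss/gpu/StandardGpuResources.h>

    // Sketch: request the CPU fallback for the coarse quantizer.
    void build_with_cpu_quantizer_fallback() {
        faiss::gpu::StandardGpuResources res;
        faiss::gpu::GpuIndexIVFFlatConfig config; // inherits GpuIndexIVFConfig
        config.allowCpuCoarseQuantizer = true;    // default is false
        faiss::gpu::GpuIndexIVFFlat index(
                &res, /*dims=*/128, /*nlist=*/1024, faiss::METRIC_L2, config);
    }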
data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
@@ -87,6 +87,8 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
     /// Trains the coarse quantizer based on the given vector data
     void train(idx_t n, const float* x) override;
 
+    void reconstruct_n(idx_t i0, idx_t n, float* out) const override;
+
    protected:
     /// Initialize appropriate index
     void setIndex_(
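The reconstruct_n override added above lets callers pull a contiguous id range of stored vectors back out of a GpuIndexIVFFlat. A minimal sketch, assuming the index was populated with the default sequential ids:

    #include <faiss/gpu/GpuIndexIVFFlat.h>
    #include <vector>

    // Sketch: copy the first n stored vectors back to the host.
    std::vector<float> dump_vectors(
            const faiss::gpu::GpuIndexIVFFlat& index, faiss::idx_t n) {
        std::vector<float> recons((size_t)n * index.d);
        index.reconstruct_n(/*i0=*/0, n, recons.data());
        return recons;
    }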
data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
@@ -34,7 +34,7 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
 
     /// Use the alternative memory layout for the IVF lists
    /// WARNING: this is a feature under development, and is only supported with
-    /// RAFT enabled for the index. Do not use if RAFT is not enabled.
+    /// cuVS enabled for the index. Do not use if cuVS is not enabled.
     bool interleavedLayout = false;
 
     /// Use GEMM-backed computation of PQ code distances for the no precomputed
data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.

data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
data/vendor/faiss/faiss/gpu/GpuResources.cpp

@@ -1,3 +1,4 @@
+// @lint-ignore-every LICENSELINT
 /**
  * Copyright (c) Facebook, Inc. and its affiliates.
  *
@@ -5,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -160,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
 
 GpuResources::~GpuResources() = default;
 
+bool GpuResources::supportsBFloat16CurrentDevice() {
+    return supportsBFloat16(getCurrentDevice());
+}
+
 cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
     return getBlasHandle(getCurrentDevice());
 }
@@ -168,7 +173,7 @@ cudaStream_t GpuResources::getDefaultStreamCurrentDevice() {
     return getDefaultStream(getCurrentDevice());
 }
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
 raft::device_resources& GpuResources::getRaftHandleCurrentDevice() {
     return getRaftHandle(getCurrentDevice());
 }
data/vendor/faiss/faiss/gpu/GpuResources.h

@@ -1,3 +1,4 @@
+// @lint-ignore-every LICENSELINT
 /**
  * Copyright (c) Facebook, Inc. and its affiliates.
  *
@@ -5,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,7 +31,7 @@
 #include <utility>
 #include <vector>
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
 #include <raft/core/device_resources.hpp>
 #include <rmm/mr/device/device_memory_resource.hpp>
 #endif
@@ -161,7 +162,7 @@ struct AllocRequest : public AllocInfo {
     /// The size in bytes of the allocation
     size_t size = 0;
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
     rmm::mr::device_memory_resource* mr = nullptr;
 #endif
 };
@@ -204,6 +205,9 @@ class GpuResources {
     /// of demand
     virtual void initializeForDevice(int device) = 0;
 
+    /// Does the given GPU support bfloat16?
+    virtual bool supportsBFloat16(int device) = 0;
+
     /// Returns the cuBLAS handle that we use for the given device
     virtual cublasHandle_t getBlasHandle(int device) = 0;
 
@@ -211,7 +215,7 @@ class GpuResources {
     /// given device
     virtual cudaStream_t getDefaultStream(int device) = 0;
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
     /// Returns the raft handle for the given device which can be used to
     /// make calls to other raft primitives.
     virtual raft::device_resources& getRaftHandle(int device) = 0;
@@ -251,6 +255,9 @@ class GpuResources {
     /// Functions provided by default
     ///
 
+    /// Does the current GPU support bfloat16?
+    bool supportsBFloat16CurrentDevice();
+
     /// Calls getBlasHandle with the current device
     cublasHandle_t getBlasHandleCurrentDevice();
 
data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp

@@ -1,3 +1,4 @@
+// @lint-ignore-every LICENSELINT
 /**
  * Copyright (c) Facebook, Inc. and its affiliates.
  *
@@ -5,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,7 +21,7 @@
  * limitations under the License.
  */
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
 #include <raft/core/device_resources.hpp>
 #include <rmm/mr/device/managed_memory_resource.hpp>
 #include <rmm/mr/device/per_device_resource.hpp>
@@ -90,7 +91,7 @@ std::string allocsToString(const std::unordered_map<void*, AllocRequest>& map) {
 
 StandardGpuResourcesImpl::StandardGpuResourcesImpl()
         :
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
           mmr_(new rmm::mr::managed_memory_resource),
           pmr_(new rmm::mr::pinned_memory_resource),
 #endif
@@ -129,6 +130,10 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
     FAISS_ASSERT_MSG(
             !allocError, "GPU memory allocations not properly cleaned up");
 
+#if defined USE_NVIDIA_CUVS
+    raftHandles_.clear();
+#endif
+
     for (auto& entry : defaultStreams_) {
         DeviceScope scope(entry.first);
 
@@ -158,7 +163,7 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
     }
 
     if (pinnedMemAlloc_) {
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
         pmr_->deallocate(pinnedMemAlloc_, pinnedMemAllocSize_);
 #else
         auto err = cudaFreeHost(pinnedMemAlloc_);
@@ -201,6 +206,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
     return requested;
 }
 
+/// Does the given GPU support bfloat16?
+bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
+    initializeForDevice(device);
+    auto& prop = getDeviceProperties(device);
+    return prop.major >= 8;
+}
+
 void StandardGpuResourcesImpl::noTempMemory() {
     setTempMemory(0);
 }
@@ -257,6 +269,14 @@ void StandardGpuResourcesImpl::setDefaultStream(
         if (prevStream != stream) {
             streamWait({stream}, {prevStream});
         }
+#if defined USE_NVIDIA_CUVS
+        // delete the raft handle for this device, which will be initialized
+        // with the updated stream during any subsequent calls to getRaftHandle
+        auto it2 = raftHandles_.find(device);
+        if (it2 != raftHandles_.end()) {
+            raft::resource::set_cuda_stream(it2->second, stream);
+        }
+#endif
     }
 
     userDefaultStreams_[device] = stream;
@@ -274,6 +294,24 @@ void StandardGpuResourcesImpl::revertDefaultStream(int device) {
         cudaStream_t newStream = defaultStreams_[device];
 
         streamWait({newStream}, {prevStream});
+
+#if defined USE_NVIDIA_CUVS
+        // update the stream on the raft handle for this device
+        auto it2 = raftHandles_.find(device);
+        if (it2 != raftHandles_.end()) {
+            raft::resource::set_cuda_stream(it2->second, newStream);
+        }
+#endif
+    } else {
+#if defined USE_NVIDIA_CUVS
+        // delete the raft handle for this device, which will be initialized
+        // with the updated stream during any subsequent calls to
+        // getRaftHandle
+        auto it2 = raftHandles_.find(device);
+        if (it2 != raftHandles_.end()) {
+            raftHandles_.erase(it2);
+        }
+#endif
     }
 }
 
@@ -307,7 +345,7 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
     // If this is the first device that we're initializing, create our
     // pinned memory allocation
     if (defaultStreams_.empty() && pinnedMemSize_ > 0) {
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
         // If this is the first device that we're initializing, create our
         // pinned memory allocation
         if (defaultStreams_.empty() && pinnedMemSize_ > 0) {
@@ -347,11 +385,20 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
             prop.major,
             prop.minor);
 
+#if USE_AMD_ROCM
+    // Our code is pre-built with and expects warpSize == 32 or 64, validate
+    // that
+    FAISS_ASSERT_FMT(
+            prop.warpSize == 32 || prop.warpSize == 64,
+            "Device id %d does not have expected warpSize of 32 or 64",
+            device);
+#else
     // Our code is pre-built with and expects warpSize == 32, validate that
     FAISS_ASSERT_FMT(
            prop.warpSize == 32,
            "Device id %d does not have expected warpSize of 32",
            device);
+#endif
 
     // Create streams
     cudaStream_t defaultStream = nullptr;
@@ -360,7 +407,7 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
 
     defaultStreams_[device] = defaultStream;
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
     raftHandles_.emplace(std::make_pair(device, defaultStream));
 #endif
 
@@ -426,7 +473,7 @@ cudaStream_t StandardGpuResourcesImpl::getDefaultStream(int device) {
     return defaultStreams_[device];
 }
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
 raft::device_resources& StandardGpuResourcesImpl::getRaftHandle(int device) {
     initializeForDevice(device);
 
@@ -497,7 +544,7 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {
         // Otherwise, we can handle this locally
        p = tempMemory_[adjReq.device]->allocMemory(adjReq.stream, adjReq.size);
     } else if (adjReq.space == MemorySpace::Device) {
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
         try {
             rmm::mr::device_memory_resource* current_mr =
                     rmm::mr::get_per_device_resource(
@@ -531,7 +578,7 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {
         }
 #endif
     } else if (adjReq.space == MemorySpace::Unified) {
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
         try {
            // for now, use our own managed MR to do Unified Memory allocations.
            // TODO: change this to use the current device resource once RMM has
@@ -600,7 +647,7 @@ void StandardGpuResourcesImpl::deallocMemory(int device, void* p) {
     } else if (
             req.space == MemorySpace::Device ||
             req.space == MemorySpace::Unified) {
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
         req.mr->deallocate_async(p, req.size, req.stream);
 #else
         auto err = cudaFree(p);
@@ -661,6 +708,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
     return res_;
 }
 
+bool StandardGpuResources::supportsBFloat16(int device) {
+    return res_->supportsBFloat16(device);
+}
+
+bool StandardGpuResources::supportsBFloat16CurrentDevice() {
+    return res_->supportsBFloat16CurrentDevice();
+}
+
 void StandardGpuResources::noTempMemory() {
     res_->noTempMemory();
 }
@@ -694,7 +749,7 @@ cudaStream_t StandardGpuResources::getDefaultStream(int device) {
     return res_->getDefaultStream(device);
 }
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
 raft::device_resources& StandardGpuResources::getRaftHandle(int device) {
     return res_->getRaftHandle(device);
 }
data/vendor/faiss/faiss/gpu/StandardGpuResources.h

@@ -1,3 +1,4 @@
+// @lint-ignore-every LICENSELINT
 /**
  * Copyright (c) Facebook, Inc. and its affiliates.
  *
@@ -5,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,7 +23,7 @@
 
 #pragma once
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
 #include <raft/core/device_resources.hpp>
 #include <rmm/mr/host/pinned_memory_resource.hpp>
 #endif
@@ -47,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
 
     ~StandardGpuResourcesImpl() override;
 
+    /// Does the given GPU support bfloat16?
+    bool supportsBFloat16(int device) override;
+
     /// Disable allocation of temporary memory; all temporary memory
     /// requests will call cudaMalloc / cudaFree at the point of use
     void noTempMemory();
@@ -79,7 +83,7 @@ class StandardGpuResourcesImpl : public GpuResources {
     /// this stream upon exit from an index or other Faiss GPU call.
     cudaStream_t getDefaultStream(int device) override;
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
     /// Returns the raft handle for the given device which can be used to
     /// make calls to other raft primitives.
     raft::device_resources& getRaftHandle(int device) override;
@@ -151,7 +155,7 @@ class StandardGpuResourcesImpl : public GpuResources {
     /// cuBLAS handle for each device
     std::unordered_map<int, cublasHandle_t> blasHandles_;
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
     /// raft handle for each device
     std::unordered_map<int, raft::device_resources> raftHandles_;
 
@@ -198,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
 
     std::shared_ptr<GpuResources> getResources() override;
 
+    /// Whether or not the given device supports native bfloat16 arithmetic
+    bool supportsBFloat16(int device);
+
+    /// Whether or not the current device supports native bfloat16 arithmetic
+    bool supportsBFloat16CurrentDevice();
+
     /// Disable allocation of temporary memory; all temporary memory
     /// requests will call cudaMalloc / cudaFree at the point of use
     void noTempMemory();
@@ -234,7 +244,7 @@ class StandardGpuResources : public GpuResourcesProvider {
     /// Returns the current default stream
     cudaStream_t getDefaultStream(int device);
 
-#if defined USE_NVIDIA_RAFT
+#if defined USE_NVIDIA_CUVS
     /// Returns the raft handle for the given device which can be used to
     /// make calls to other raft primitives.
     raft::device_resources& getRaftHandle(int device);
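Together, the GpuResources and StandardGpuResources hunks add a small capability query for bfloat16: per the StandardGpuResourcesImpl implementation above it returns true for devices of compute capability 8.0 or higher (Ampere and newer). An illustrative sketch; which bf16 option it should gate lies outside this diff:

    #include <faiss/gpu/StandardGpuResources.h>

    // Sketch: check for native bfloat16 support before enabling a bf16 path.
    bool bf16_available() {
        faiss::gpu::StandardGpuResources res;
        return res.supportsBFloat16CurrentDevice();
    }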
data/vendor/faiss/faiss/gpu/impl/IndexUtils.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp

@@ -1,11 +1,12 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
  */
 
 #include <faiss/gpu/impl/InterleavedCodes.h>
+#include <faiss/gpu/utils/DeviceUtils.h>
 #include <faiss/gpu/utils/StaticUtils.h>
 #include <faiss/impl/FaissAssert.h>
 
@@ -166,15 +167,16 @@ void unpackInterleavedWord(
         int numVecs,
         int dims,
         int bitsPerCode) {
-    int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
+    int warpSize = getWarpSizeCurrentDevice();
+    int wordsPerDimBlock = (size_t)warpSize * bitsPerCode / (8 * sizeof(T));
     int wordsPerBlock = wordsPerDimBlock * dims;
-    int numBlocks = utils::divUp(numVecs, 32);
+    int numBlocks = utils::divUp(numVecs, warpSize);
 
 #pragma omp parallel for
     for (int i = 0; i < numVecs; ++i) {
-        int block = i / 32;
+        int block = i / warpSize;
         FAISS_ASSERT(block < numBlocks);
-        int lane = i % 32;
+        int lane = i % warpSize;
 
         for (int j = 0; j < dims; ++j) {
             int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
@@ -188,9 +190,10 @@ std::vector<uint8_t> unpackInterleaved(
         int numVecs,
         int dims,
         int bitsPerCode) {
-    int bytesPerDimBlock = 32 * bitsPerCode / 8;
+    int warpSize = getWarpSizeCurrentDevice();
+    int bytesPerDimBlock = warpSize * bitsPerCode / 8;
     int bytesPerBlock = bytesPerDimBlock * dims;
-    int numBlocks = utils::divUp(numVecs, 32);
+    int numBlocks = utils::divUp(numVecs, warpSize);
     size_t totalSize = (size_t)bytesPerBlock * numBlocks;
     FAISS_ASSERT(data.size() == totalSize);
 
@@ -217,8 +220,8 @@ std::vector<uint8_t> unpackInterleaved(
     } else if (bitsPerCode == 4) {
 #pragma omp parallel for
         for (int i = 0; i < numVecs; ++i) {
-            int block = i / 32;
-            int lane = i % 32;
+            int block = i / warpSize;
+            int lane = i % warpSize;
 
             int word = lane / 2;
             int subWord = lane % 2;
@@ -235,8 +238,8 @@ std::vector<uint8_t> unpackInterleaved(
     } else if (bitsPerCode == 5) {
 #pragma omp parallel for
         for (int i = 0; i < numVecs; ++i) {
-            int block = i / 32;
-            int blockVector = i % 32;
+            int block = i / warpSize;
+            int blockVector = i % warpSize;
 
             for (int j = 0; j < dims; ++j) {
                 uint8_t* dimBlock =
@@ -257,8 +260,8 @@ std::vector<uint8_t> unpackInterleaved(
     } else if (bitsPerCode == 6) {
 #pragma omp parallel for
         for (int i = 0; i < numVecs; ++i) {
-            int block = i / 32;
-            int blockVector = i % 32;
+            int block = i / warpSize;
+            int blockVector = i % warpSize;
 
             for (int j = 0; j < dims; ++j) {
                 uint8_t* dimBlock =
@@ -442,17 +445,18 @@ void packInterleavedWord(
         int numVecs,
         int dims,
         int bitsPerCode) {
-    int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
+    int warpSize = getWarpSizeCurrentDevice();
+    int wordsPerDimBlock = (size_t)warpSize * bitsPerCode / (8 * sizeof(T));
     int wordsPerBlock = wordsPerDimBlock * dims;
-    int numBlocks = utils::divUp(numVecs, 32);
+    int numBlocks = utils::divUp(numVecs, warpSize);
 
     // We're guaranteed that all other slots not filled by the vectors present
     // are initialized to zero (from the vector constructor in packInterleaved)
 #pragma omp parallel for
     for (int i = 0; i < numVecs; ++i) {
-        int block = i / 32;
+        int block = i / warpSize;
         FAISS_ASSERT(block < numBlocks);
-        int lane = i % 32;
+        int lane = i % warpSize;
 
         for (int j = 0; j < dims; ++j) {
             int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
@@ -466,9 +470,10 @@ std::vector<uint8_t> packInterleaved(
         int numVecs,
         int dims,
         int bitsPerCode) {
-    int bytesPerDimBlock = 32 * bitsPerCode / 8;
+    int warpSize = getWarpSizeCurrentDevice();
+    int bytesPerDimBlock = warpSize * bitsPerCode / 8;
     int bytesPerBlock = bytesPerDimBlock * dims;
-    int numBlocks = utils::divUp(numVecs, 32);
+    int numBlocks = utils::divUp(numVecs, warpSize);
     size_t totalSize = (size_t)bytesPerBlock * numBlocks;
 
     // bit codes padded to whole bytes
@@ -499,7 +504,7 @@ std::vector<uint8_t> packInterleaved(
         for (int i = 0; i < numBlocks; ++i) {
             for (int j = 0; j < dims; ++j) {
                 for (int k = 0; k < bytesPerDimBlock; ++k) {
-                    int loVec = i * 32 + k * 2;
+                    int loVec = i * warpSize + k * 2;
                     int hiVec = loVec + 1;
 
                     uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
@@ -516,7 +521,7 @@ std::vector<uint8_t> packInterleaved(
             for (int j = 0; j < dims; ++j) {
                 for (int k = 0; k < bytesPerDimBlock; ++k) {
                     // What input vectors we are pulling from
-                    int loVec = i * 32 + (k * 8) / 5;
+                    int loVec = i * warpSize + (k * 8) / 5;
                     int hiVec = loVec + 1;
                     int hiVec2 = hiVec + 1;
 
@@ -536,7 +541,7 @@ std::vector<uint8_t> packInterleaved(
             for (int j = 0; j < dims; ++j) {
                 for (int k = 0; k < bytesPerDimBlock; ++k) {
                     // What input vectors we are pulling from
-                    int loVec = i * 32 + (k * 8) / 6;
+                    int loVec = i * warpSize + (k * 8) / 6;
                     int hiVec = loVec + 1;
 
                     uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
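The InterleavedCodes changes replace the hard-coded 32-lane block width with the device warp size, so the packing and unpacking helpers also work on AMD wavefronts of width 64 (the ROCm branch in StandardGpuResources.cpp now accepts warpSize 32 or 64). A host-side recap of the layout arithmetic, mirroring the expressions in the hunks above:

    #include <cstddef>

    // With 4-bit codes a dim-block spans 32 * 4 / 8 = 16 bytes on a 32-wide
    // warp and 64 * 4 / 8 = 32 bytes on a 64-wide wavefront.
    size_t interleavedBytes(int numVecs, int dims, int bitsPerCode, int warpSize) {
        int bytesPerDimBlock = warpSize * bitsPerCode / 8;
        int bytesPerBlock = bytesPerDimBlock * dims;
        int numBlocks = (numVecs + warpSize - 1) / warpSize; // utils::divUp
        return (size_t)bytesPerBlock * numBlocks;            // totalSize above
    }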
data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.

data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.

data/vendor/faiss/faiss/gpu/impl/RemapIndices.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.

data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.

data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp

@@ -1,5 +1,5 @@
-/**
- * Copyright (c) Facebook, Inc. and its affiliates.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
  *
  * This source code is licensed under the MIT license found in the
  * LICENSE file in the root directory of this source tree.
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include <cuda_profiler_api.h>
+#include <faiss/impl/AuxIndexStructures.h>
 
 DEFINE_int32(num, 10000, "# of vecs");
 DEFINE_int32(k, 100, "# of clusters");
@@ -34,6 +35,7 @@ DEFINE_int64(
         "minimum size to use CPU -> GPU paged copies");
 DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use");
 DEFINE_int32(max_points, -1, "max points per centroid");
+DEFINE_double(timeout, 0, "timeout in seconds");
 
 using namespace faiss::gpu;
 
@@ -99,10 +101,14 @@ int main(int argc, char** argv) {
         cp.max_points_per_centroid = FLAGS_max_points;
     }
 
+    auto tc = new faiss::TimeoutCallback();
+    faiss::InterruptCallback::instance.reset(tc);
+
     faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp);
 
     // Time k-means
     {
+        tc->set_timeout(FLAGS_timeout);
        CpuTimer timer;
 
        kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex()));
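The PerfClustering change exercises the new faiss::TimeoutCallback added to impl/AuxIndexStructures in this release: an InterruptCallback that makes long-running train/search calls raise an exception once a wall-clock budget is exceeded. A minimal sketch following the same pattern as the hunk (whether a timeout of 0 disables the check is an assumption):

    #include <faiss/impl/AuxIndexStructures.h>

    // Sketch: install a process-wide timeout for interruptible faiss calls.
    void set_global_timeout(double seconds) {
        auto* tc = new faiss::TimeoutCallback();
        faiss::InterruptCallback::instance.reset(tc); // takes ownership
        tc->set_timeout(seconds); // 0 presumably means no timeout (assumption)
    }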