faiss 0.3.1 → 0.3.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (293) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -26,6 +26,12 @@ struct GpuIndexIVFConfig : public GpuIndexConfig {
26
26
 
27
27
  /// Configuration for the coarse quantizer object
28
28
  GpuIndexFlatConfig flatConfig;
29
+
30
+ /// This flag controls the CPU fallback logic for coarse quantizer
31
+ /// component of the index. When set to false (default), the cloner will
32
+ /// throw an exception for indices not implemented on GPU. When set to
33
+ /// true, it will fallback to a CPU implementation.
34
+ bool allowCpuCoarseQuantizer = false;
29
35
  };
30
36
 
31
37
  /// Base class of all GPU IVF index types. This (for now) deliberately does not
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -87,6 +87,8 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
87
87
  /// Trains the coarse quantizer based on the given vector data
88
88
  void train(idx_t n, const float* x) override;
89
89
 
90
+ void reconstruct_n(idx_t i0, idx_t n, float* out) const override;
91
+
90
92
  protected:
91
93
  /// Initialize appropriate index
92
94
  void setIndex_(
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -34,7 +34,7 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
34
34
 
35
35
  /// Use the alternative memory layout for the IVF lists
36
36
  /// WARNING: this is a feature under development, and is only supported with
37
- /// RAFT enabled for the index. Do not use if RAFT is not enabled.
37
+ /// cuVS enabled for the index. Do not use if cuVS is not enabled.
38
38
  bool interleavedLayout = false;
39
39
 
40
40
  /// Use GEMM-backed computation of PQ code distances for the no precomputed
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,3 +1,4 @@
1
+ // @lint-ignore-every LICENSELINT
1
2
  /**
2
3
  * Copyright (c) Facebook, Inc. and its affiliates.
3
4
  *
@@ -5,7 +6,7 @@
5
6
  * LICENSE file in the root directory of this source tree.
6
7
  */
7
8
  /*
8
- * Copyright (c) 2023, NVIDIA CORPORATION.
9
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
9
10
  *
10
11
  * Licensed under the Apache License, Version 2.0 (the "License");
11
12
  * you may not use this file except in compliance with the License.
@@ -160,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
160
161
 
161
162
  GpuResources::~GpuResources() = default;
162
163
 
164
+ bool GpuResources::supportsBFloat16CurrentDevice() {
165
+ return supportsBFloat16(getCurrentDevice());
166
+ }
167
+
163
168
  cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
164
169
  return getBlasHandle(getCurrentDevice());
165
170
  }
@@ -168,7 +173,7 @@ cudaStream_t GpuResources::getDefaultStreamCurrentDevice() {
168
173
  return getDefaultStream(getCurrentDevice());
169
174
  }
170
175
 
171
- #if defined USE_NVIDIA_RAFT
176
+ #if defined USE_NVIDIA_CUVS
172
177
  raft::device_resources& GpuResources::getRaftHandleCurrentDevice() {
173
178
  return getRaftHandle(getCurrentDevice());
174
179
  }
@@ -1,3 +1,4 @@
1
+ // @lint-ignore-every LICENSELINT
1
2
  /**
2
3
  * Copyright (c) Facebook, Inc. and its affiliates.
3
4
  *
@@ -5,7 +6,7 @@
5
6
  * LICENSE file in the root directory of this source tree.
6
7
  */
7
8
  /*
8
- * Copyright (c) 2023, NVIDIA CORPORATION.
9
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
9
10
  *
10
11
  * Licensed under the Apache License, Version 2.0 (the "License");
11
12
  * you may not use this file except in compliance with the License.
@@ -30,7 +31,7 @@
30
31
  #include <utility>
31
32
  #include <vector>
32
33
 
33
- #if defined USE_NVIDIA_RAFT
34
+ #if defined USE_NVIDIA_CUVS
34
35
  #include <raft/core/device_resources.hpp>
35
36
  #include <rmm/mr/device/device_memory_resource.hpp>
36
37
  #endif
@@ -161,7 +162,7 @@ struct AllocRequest : public AllocInfo {
161
162
  /// The size in bytes of the allocation
162
163
  size_t size = 0;
163
164
 
164
- #if defined USE_NVIDIA_RAFT
165
+ #if defined USE_NVIDIA_CUVS
165
166
  rmm::mr::device_memory_resource* mr = nullptr;
166
167
  #endif
167
168
  };
@@ -204,6 +205,9 @@ class GpuResources {
204
205
  /// of demand
205
206
  virtual void initializeForDevice(int device) = 0;
206
207
 
208
+ /// Does the given GPU support bfloat16?
209
+ virtual bool supportsBFloat16(int device) = 0;
210
+
207
211
  /// Returns the cuBLAS handle that we use for the given device
208
212
  virtual cublasHandle_t getBlasHandle(int device) = 0;
209
213
 
@@ -211,7 +215,7 @@ class GpuResources {
211
215
  /// given device
212
216
  virtual cudaStream_t getDefaultStream(int device) = 0;
213
217
 
214
- #if defined USE_NVIDIA_RAFT
218
+ #if defined USE_NVIDIA_CUVS
215
219
  /// Returns the raft handle for the given device which can be used to
216
220
  /// make calls to other raft primitives.
217
221
  virtual raft::device_resources& getRaftHandle(int device) = 0;
@@ -251,6 +255,9 @@ class GpuResources {
251
255
  /// Functions provided by default
252
256
  ///
253
257
 
258
+ /// Does the current GPU support bfloat16?
259
+ bool supportsBFloat16CurrentDevice();
260
+
254
261
  /// Calls getBlasHandle with the current device
255
262
  cublasHandle_t getBlasHandleCurrentDevice();
256
263
 
@@ -1,3 +1,4 @@
1
+ // @lint-ignore-every LICENSELINT
1
2
  /**
2
3
  * Copyright (c) Facebook, Inc. and its affiliates.
3
4
  *
@@ -5,7 +6,7 @@
5
6
  * LICENSE file in the root directory of this source tree.
6
7
  */
7
8
  /*
8
- * Copyright (c) 2023, NVIDIA CORPORATION.
9
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
9
10
  *
10
11
  * Licensed under the Apache License, Version 2.0 (the "License");
11
12
  * you may not use this file except in compliance with the License.
@@ -20,7 +21,7 @@
20
21
  * limitations under the License.
21
22
  */
22
23
 
23
- #if defined USE_NVIDIA_RAFT
24
+ #if defined USE_NVIDIA_CUVS
24
25
  #include <raft/core/device_resources.hpp>
25
26
  #include <rmm/mr/device/managed_memory_resource.hpp>
26
27
  #include <rmm/mr/device/per_device_resource.hpp>
@@ -90,7 +91,7 @@ std::string allocsToString(const std::unordered_map<void*, AllocRequest>& map) {
90
91
 
91
92
  StandardGpuResourcesImpl::StandardGpuResourcesImpl()
92
93
  :
93
- #if defined USE_NVIDIA_RAFT
94
+ #if defined USE_NVIDIA_CUVS
94
95
  mmr_(new rmm::mr::managed_memory_resource),
95
96
  pmr_(new rmm::mr::pinned_memory_resource),
96
97
  #endif
@@ -129,6 +130,10 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
129
130
  FAISS_ASSERT_MSG(
130
131
  !allocError, "GPU memory allocations not properly cleaned up");
131
132
 
133
+ #if defined USE_NVIDIA_CUVS
134
+ raftHandles_.clear();
135
+ #endif
136
+
132
137
  for (auto& entry : defaultStreams_) {
133
138
  DeviceScope scope(entry.first);
134
139
 
@@ -158,7 +163,7 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() {
158
163
  }
159
164
 
160
165
  if (pinnedMemAlloc_) {
161
- #if defined USE_NVIDIA_RAFT
166
+ #if defined USE_NVIDIA_CUVS
162
167
  pmr_->deallocate(pinnedMemAlloc_, pinnedMemAllocSize_);
163
168
  #else
164
169
  auto err = cudaFreeHost(pinnedMemAlloc_);
@@ -201,6 +206,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
201
206
  return requested;
202
207
  }
203
208
 
209
+ /// Does the given GPU support bfloat16?
210
+ bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
211
+ initializeForDevice(device);
212
+ auto& prop = getDeviceProperties(device);
213
+ return prop.major >= 8;
214
+ }
215
+
204
216
  void StandardGpuResourcesImpl::noTempMemory() {
205
217
  setTempMemory(0);
206
218
  }
@@ -257,6 +269,14 @@ void StandardGpuResourcesImpl::setDefaultStream(
257
269
  if (prevStream != stream) {
258
270
  streamWait({stream}, {prevStream});
259
271
  }
272
+ #if defined USE_NVIDIA_CUVS
273
+ // delete the raft handle for this device, which will be initialized
274
+ // with the updated stream during any subsequent calls to getRaftHandle
275
+ auto it2 = raftHandles_.find(device);
276
+ if (it2 != raftHandles_.end()) {
277
+ raft::resource::set_cuda_stream(it2->second, stream);
278
+ }
279
+ #endif
260
280
  }
261
281
 
262
282
  userDefaultStreams_[device] = stream;
@@ -274,6 +294,24 @@ void StandardGpuResourcesImpl::revertDefaultStream(int device) {
274
294
  cudaStream_t newStream = defaultStreams_[device];
275
295
 
276
296
  streamWait({newStream}, {prevStream});
297
+
298
+ #if defined USE_NVIDIA_CUVS
299
+ // update the stream on the raft handle for this device
300
+ auto it2 = raftHandles_.find(device);
301
+ if (it2 != raftHandles_.end()) {
302
+ raft::resource::set_cuda_stream(it2->second, newStream);
303
+ }
304
+ #endif
305
+ } else {
306
+ #if defined USE_NVIDIA_CUVS
307
+ // delete the raft handle for this device, which will be initialized
308
+ // with the updated stream during any subsequent calls to
309
+ // getRaftHandle
310
+ auto it2 = raftHandles_.find(device);
311
+ if (it2 != raftHandles_.end()) {
312
+ raftHandles_.erase(it2);
313
+ }
314
+ #endif
277
315
  }
278
316
  }
279
317
 
@@ -307,7 +345,7 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
307
345
  // If this is the first device that we're initializing, create our
308
346
  // pinned memory allocation
309
347
  if (defaultStreams_.empty() && pinnedMemSize_ > 0) {
310
- #if defined USE_NVIDIA_RAFT
348
+ #if defined USE_NVIDIA_CUVS
311
349
  // If this is the first device that we're initializing, create our
312
350
  // pinned memory allocation
313
351
  if (defaultStreams_.empty() && pinnedMemSize_ > 0) {
@@ -347,11 +385,20 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
347
385
  prop.major,
348
386
  prop.minor);
349
387
 
388
+ #if USE_AMD_ROCM
389
+ // Our code is pre-built with and expects warpSize == 32 or 64, validate
390
+ // that
391
+ FAISS_ASSERT_FMT(
392
+ prop.warpSize == 32 || prop.warpSize == 64,
393
+ "Device id %d does not have expected warpSize of 32 or 64",
394
+ device);
395
+ #else
350
396
  // Our code is pre-built with and expects warpSize == 32, validate that
351
397
  FAISS_ASSERT_FMT(
352
398
  prop.warpSize == 32,
353
399
  "Device id %d does not have expected warpSize of 32",
354
400
  device);
401
+ #endif
355
402
 
356
403
  // Create streams
357
404
  cudaStream_t defaultStream = nullptr;
@@ -360,7 +407,7 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) {
360
407
 
361
408
  defaultStreams_[device] = defaultStream;
362
409
 
363
- #if defined USE_NVIDIA_RAFT
410
+ #if defined USE_NVIDIA_CUVS
364
411
  raftHandles_.emplace(std::make_pair(device, defaultStream));
365
412
  #endif
366
413
 
@@ -426,7 +473,7 @@ cudaStream_t StandardGpuResourcesImpl::getDefaultStream(int device) {
426
473
  return defaultStreams_[device];
427
474
  }
428
475
 
429
- #if defined USE_NVIDIA_RAFT
476
+ #if defined USE_NVIDIA_CUVS
430
477
  raft::device_resources& StandardGpuResourcesImpl::getRaftHandle(int device) {
431
478
  initializeForDevice(device);
432
479
 
@@ -497,7 +544,7 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {
497
544
  // Otherwise, we can handle this locally
498
545
  p = tempMemory_[adjReq.device]->allocMemory(adjReq.stream, adjReq.size);
499
546
  } else if (adjReq.space == MemorySpace::Device) {
500
- #if defined USE_NVIDIA_RAFT
547
+ #if defined USE_NVIDIA_CUVS
501
548
  try {
502
549
  rmm::mr::device_memory_resource* current_mr =
503
550
  rmm::mr::get_per_device_resource(
@@ -531,7 +578,7 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {
531
578
  }
532
579
  #endif
533
580
  } else if (adjReq.space == MemorySpace::Unified) {
534
- #if defined USE_NVIDIA_RAFT
581
+ #if defined USE_NVIDIA_CUVS
535
582
  try {
536
583
  // for now, use our own managed MR to do Unified Memory allocations.
537
584
  // TODO: change this to use the current device resource once RMM has
@@ -600,7 +647,7 @@ void StandardGpuResourcesImpl::deallocMemory(int device, void* p) {
600
647
  } else if (
601
648
  req.space == MemorySpace::Device ||
602
649
  req.space == MemorySpace::Unified) {
603
- #if defined USE_NVIDIA_RAFT
650
+ #if defined USE_NVIDIA_CUVS
604
651
  req.mr->deallocate_async(p, req.size, req.stream);
605
652
  #else
606
653
  auto err = cudaFree(p);
@@ -661,6 +708,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
661
708
  return res_;
662
709
  }
663
710
 
711
+ bool StandardGpuResources::supportsBFloat16(int device) {
712
+ return res_->supportsBFloat16(device);
713
+ }
714
+
715
+ bool StandardGpuResources::supportsBFloat16CurrentDevice() {
716
+ return res_->supportsBFloat16CurrentDevice();
717
+ }
718
+
664
719
  void StandardGpuResources::noTempMemory() {
665
720
  res_->noTempMemory();
666
721
  }
@@ -694,7 +749,7 @@ cudaStream_t StandardGpuResources::getDefaultStream(int device) {
694
749
  return res_->getDefaultStream(device);
695
750
  }
696
751
 
697
- #if defined USE_NVIDIA_RAFT
752
+ #if defined USE_NVIDIA_CUVS
698
753
  raft::device_resources& StandardGpuResources::getRaftHandle(int device) {
699
754
  return res_->getRaftHandle(device);
700
755
  }
@@ -1,3 +1,4 @@
1
+ // @lint-ignore-every LICENSELINT
1
2
  /**
2
3
  * Copyright (c) Facebook, Inc. and its affiliates.
3
4
  *
@@ -5,7 +6,7 @@
5
6
  * LICENSE file in the root directory of this source tree.
6
7
  */
7
8
  /*
8
- * Copyright (c) 2023, NVIDIA CORPORATION.
9
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
9
10
  *
10
11
  * Licensed under the Apache License, Version 2.0 (the "License");
11
12
  * you may not use this file except in compliance with the License.
@@ -22,7 +23,7 @@
22
23
 
23
24
  #pragma once
24
25
 
25
- #if defined USE_NVIDIA_RAFT
26
+ #if defined USE_NVIDIA_CUVS
26
27
  #include <raft/core/device_resources.hpp>
27
28
  #include <rmm/mr/host/pinned_memory_resource.hpp>
28
29
  #endif
@@ -47,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
47
48
 
48
49
  ~StandardGpuResourcesImpl() override;
49
50
 
51
+ /// Does the given GPU support bfloat16?
52
+ bool supportsBFloat16(int device) override;
53
+
50
54
  /// Disable allocation of temporary memory; all temporary memory
51
55
  /// requests will call cudaMalloc / cudaFree at the point of use
52
56
  void noTempMemory();
@@ -79,7 +83,7 @@ class StandardGpuResourcesImpl : public GpuResources {
79
83
  /// this stream upon exit from an index or other Faiss GPU call.
80
84
  cudaStream_t getDefaultStream(int device) override;
81
85
 
82
- #if defined USE_NVIDIA_RAFT
86
+ #if defined USE_NVIDIA_CUVS
83
87
  /// Returns the raft handle for the given device which can be used to
84
88
  /// make calls to other raft primitives.
85
89
  raft::device_resources& getRaftHandle(int device) override;
@@ -151,7 +155,7 @@ class StandardGpuResourcesImpl : public GpuResources {
151
155
  /// cuBLAS handle for each device
152
156
  std::unordered_map<int, cublasHandle_t> blasHandles_;
153
157
 
154
- #if defined USE_NVIDIA_RAFT
158
+ #if defined USE_NVIDIA_CUVS
155
159
  /// raft handle for each device
156
160
  std::unordered_map<int, raft::device_resources> raftHandles_;
157
161
 
@@ -198,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
198
202
 
199
203
  std::shared_ptr<GpuResources> getResources() override;
200
204
 
205
+ /// Whether or not the given device supports native bfloat16 arithmetic
206
+ bool supportsBFloat16(int device);
207
+
208
+ /// Whether or not the current device supports native bfloat16 arithmetic
209
+ bool supportsBFloat16CurrentDevice();
210
+
201
211
  /// Disable allocation of temporary memory; all temporary memory
202
212
  /// requests will call cudaMalloc / cudaFree at the point of use
203
213
  void noTempMemory();
@@ -234,7 +244,7 @@ class StandardGpuResources : public GpuResourcesProvider {
234
244
  /// Returns the current default stream
235
245
  cudaStream_t getDefaultStream(int device);
236
246
 
237
- #if defined USE_NVIDIA_RAFT
247
+ #if defined USE_NVIDIA_CUVS
238
248
  /// Returns the raft handle for the given device which can be used to
239
249
  /// make calls to other raft primitives.
240
250
  raft::device_resources& getRaftHandle(int device);
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,11 +1,12 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
8
  #include <faiss/gpu/impl/InterleavedCodes.h>
9
+ #include <faiss/gpu/utils/DeviceUtils.h>
9
10
  #include <faiss/gpu/utils/StaticUtils.h>
10
11
  #include <faiss/impl/FaissAssert.h>
11
12
 
@@ -166,15 +167,16 @@ void unpackInterleavedWord(
166
167
  int numVecs,
167
168
  int dims,
168
169
  int bitsPerCode) {
169
- int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
170
+ int warpSize = getWarpSizeCurrentDevice();
171
+ int wordsPerDimBlock = (size_t)warpSize * bitsPerCode / (8 * sizeof(T));
170
172
  int wordsPerBlock = wordsPerDimBlock * dims;
171
- int numBlocks = utils::divUp(numVecs, 32);
173
+ int numBlocks = utils::divUp(numVecs, warpSize);
172
174
 
173
175
  #pragma omp parallel for
174
176
  for (int i = 0; i < numVecs; ++i) {
175
- int block = i / 32;
177
+ int block = i / warpSize;
176
178
  FAISS_ASSERT(block < numBlocks);
177
- int lane = i % 32;
179
+ int lane = i % warpSize;
178
180
 
179
181
  for (int j = 0; j < dims; ++j) {
180
182
  int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
@@ -188,9 +190,10 @@ std::vector<uint8_t> unpackInterleaved(
188
190
  int numVecs,
189
191
  int dims,
190
192
  int bitsPerCode) {
191
- int bytesPerDimBlock = 32 * bitsPerCode / 8;
193
+ int warpSize = getWarpSizeCurrentDevice();
194
+ int bytesPerDimBlock = warpSize * bitsPerCode / 8;
192
195
  int bytesPerBlock = bytesPerDimBlock * dims;
193
- int numBlocks = utils::divUp(numVecs, 32);
196
+ int numBlocks = utils::divUp(numVecs, warpSize);
194
197
  size_t totalSize = (size_t)bytesPerBlock * numBlocks;
195
198
  FAISS_ASSERT(data.size() == totalSize);
196
199
 
@@ -217,8 +220,8 @@ std::vector<uint8_t> unpackInterleaved(
217
220
  } else if (bitsPerCode == 4) {
218
221
  #pragma omp parallel for
219
222
  for (int i = 0; i < numVecs; ++i) {
220
- int block = i / 32;
221
- int lane = i % 32;
223
+ int block = i / warpSize;
224
+ int lane = i % warpSize;
222
225
 
223
226
  int word = lane / 2;
224
227
  int subWord = lane % 2;
@@ -235,8 +238,8 @@ std::vector<uint8_t> unpackInterleaved(
235
238
  } else if (bitsPerCode == 5) {
236
239
  #pragma omp parallel for
237
240
  for (int i = 0; i < numVecs; ++i) {
238
- int block = i / 32;
239
- int blockVector = i % 32;
241
+ int block = i / warpSize;
242
+ int blockVector = i % warpSize;
240
243
 
241
244
  for (int j = 0; j < dims; ++j) {
242
245
  uint8_t* dimBlock =
@@ -257,8 +260,8 @@ std::vector<uint8_t> unpackInterleaved(
257
260
  } else if (bitsPerCode == 6) {
258
261
  #pragma omp parallel for
259
262
  for (int i = 0; i < numVecs; ++i) {
260
- int block = i / 32;
261
- int blockVector = i % 32;
263
+ int block = i / warpSize;
264
+ int blockVector = i % warpSize;
262
265
 
263
266
  for (int j = 0; j < dims; ++j) {
264
267
  uint8_t* dimBlock =
@@ -442,17 +445,18 @@ void packInterleavedWord(
442
445
  int numVecs,
443
446
  int dims,
444
447
  int bitsPerCode) {
445
- int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
448
+ int warpSize = getWarpSizeCurrentDevice();
449
+ int wordsPerDimBlock = (size_t)warpSize * bitsPerCode / (8 * sizeof(T));
446
450
  int wordsPerBlock = wordsPerDimBlock * dims;
447
- int numBlocks = utils::divUp(numVecs, 32);
451
+ int numBlocks = utils::divUp(numVecs, warpSize);
448
452
 
449
453
  // We're guaranteed that all other slots not filled by the vectors present
450
454
  // are initialized to zero (from the vector constructor in packInterleaved)
451
455
  #pragma omp parallel for
452
456
  for (int i = 0; i < numVecs; ++i) {
453
- int block = i / 32;
457
+ int block = i / warpSize;
454
458
  FAISS_ASSERT(block < numBlocks);
455
- int lane = i % 32;
459
+ int lane = i % warpSize;
456
460
 
457
461
  for (int j = 0; j < dims; ++j) {
458
462
  int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
@@ -466,9 +470,10 @@ std::vector<uint8_t> packInterleaved(
466
470
  int numVecs,
467
471
  int dims,
468
472
  int bitsPerCode) {
469
- int bytesPerDimBlock = 32 * bitsPerCode / 8;
473
+ int warpSize = getWarpSizeCurrentDevice();
474
+ int bytesPerDimBlock = warpSize * bitsPerCode / 8;
470
475
  int bytesPerBlock = bytesPerDimBlock * dims;
471
- int numBlocks = utils::divUp(numVecs, 32);
476
+ int numBlocks = utils::divUp(numVecs, warpSize);
472
477
  size_t totalSize = (size_t)bytesPerBlock * numBlocks;
473
478
 
474
479
  // bit codes padded to whole bytes
@@ -499,7 +504,7 @@ std::vector<uint8_t> packInterleaved(
499
504
  for (int i = 0; i < numBlocks; ++i) {
500
505
  for (int j = 0; j < dims; ++j) {
501
506
  for (int k = 0; k < bytesPerDimBlock; ++k) {
502
- int loVec = i * 32 + k * 2;
507
+ int loVec = i * warpSize + k * 2;
503
508
  int hiVec = loVec + 1;
504
509
 
505
510
  uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
@@ -516,7 +521,7 @@ std::vector<uint8_t> packInterleaved(
516
521
  for (int j = 0; j < dims; ++j) {
517
522
  for (int k = 0; k < bytesPerDimBlock; ++k) {
518
523
  // What input vectors we are pulling from
519
- int loVec = i * 32 + (k * 8) / 5;
524
+ int loVec = i * warpSize + (k * 8) / 5;
520
525
  int hiVec = loVec + 1;
521
526
  int hiVec2 = hiVec + 1;
522
527
 
@@ -536,7 +541,7 @@ std::vector<uint8_t> packInterleaved(
536
541
  for (int j = 0; j < dims; ++j) {
537
542
  for (int k = 0; k < bytesPerDimBlock; ++k) {
538
543
  // What input vectors we are pulling from
539
- int loVec = i * 32 + (k * 8) / 6;
544
+ int loVec = i * warpSize + (k * 8) / 6;
540
545
  int hiVec = loVec + 1;
541
546
 
542
547
  uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -1,5 +1,5 @@
1
- /**
2
- * Copyright (c) Facebook, Inc. and its affiliates.
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
3
  *
4
4
  * This source code is licensed under the MIT license found in the
5
5
  * LICENSE file in the root directory of this source tree.
@@ -17,6 +17,7 @@
17
17
  #include <vector>
18
18
 
19
19
  #include <cuda_profiler_api.h>
20
+ #include <faiss/impl/AuxIndexStructures.h>
20
21
 
21
22
  DEFINE_int32(num, 10000, "# of vecs");
22
23
  DEFINE_int32(k, 100, "# of clusters");
@@ -34,6 +35,7 @@ DEFINE_int64(
34
35
  "minimum size to use CPU -> GPU paged copies");
35
36
  DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use");
36
37
  DEFINE_int32(max_points, -1, "max points per centroid");
38
+ DEFINE_double(timeout, 0, "timeout in seconds");
37
39
 
38
40
  using namespace faiss::gpu;
39
41
 
@@ -99,10 +101,14 @@ int main(int argc, char** argv) {
99
101
  cp.max_points_per_centroid = FLAGS_max_points;
100
102
  }
101
103
 
104
+ auto tc = new faiss::TimeoutCallback();
105
+ faiss::InterruptCallback::instance.reset(tc);
106
+
102
107
  faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp);
103
108
 
104
109
  // Time k-means
105
110
  {
111
+ tc->set_timeout(FLAGS_timeout);
106
112
  CpuTimer timer;
107
113
 
108
114
  kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex()));