faiss 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (293)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/AutoTune.cpp +2 -2
  6. data/vendor/faiss/faiss/AutoTune.h +3 -3
  7. data/vendor/faiss/faiss/Clustering.cpp +37 -6
  8. data/vendor/faiss/faiss/Clustering.h +12 -3
  9. data/vendor/faiss/faiss/IVFlib.cpp +6 -3
  10. data/vendor/faiss/faiss/IVFlib.h +2 -2
  11. data/vendor/faiss/faiss/Index.cpp +6 -2
  12. data/vendor/faiss/faiss/Index.h +30 -8
  13. data/vendor/faiss/faiss/Index2Layer.cpp +2 -2
  14. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +7 -7
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +14 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +2 -2
  19. data/vendor/faiss/faiss/IndexBinary.cpp +13 -2
  20. data/vendor/faiss/faiss/IndexBinary.h +8 -2
  21. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -2
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -2
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +2 -7
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.h +3 -3
  27. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -3
  28. data/vendor/faiss/faiss/IndexBinaryHash.h +2 -2
  29. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  30. data/vendor/faiss/faiss/IndexBinaryIVF.h +2 -2
  31. data/vendor/faiss/faiss/IndexFastScan.cpp +32 -18
  32. data/vendor/faiss/faiss/IndexFastScan.h +11 -2
  33. data/vendor/faiss/faiss/IndexFlat.cpp +13 -10
  34. data/vendor/faiss/faiss/IndexFlat.h +2 -2
  35. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -7
  36. data/vendor/faiss/faiss/IndexFlatCodes.h +25 -5
  37. data/vendor/faiss/faiss/IndexHNSW.cpp +156 -96
  38. data/vendor/faiss/faiss/IndexHNSW.h +54 -5
  39. data/vendor/faiss/faiss/IndexIDMap.cpp +19 -3
  40. data/vendor/faiss/faiss/IndexIDMap.h +5 -2
  41. data/vendor/faiss/faiss/IndexIVF.cpp +5 -6
  42. data/vendor/faiss/faiss/IndexIVF.h +13 -4
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +21 -7
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +5 -2
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +3 -14
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -4
  47. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +201 -91
  48. data/vendor/faiss/faiss/IndexIVFFastScan.h +33 -9
  49. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -2
  50. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -2
  51. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +2 -2
  52. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +2 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -6
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -2
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +7 -14
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -4
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -2
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -2
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -3
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -2
  61. data/vendor/faiss/faiss/IndexLSH.cpp +2 -3
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -2
  63. data/vendor/faiss/faiss/IndexLattice.cpp +3 -21
  64. data/vendor/faiss/faiss/IndexLattice.h +5 -24
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +2 -31
  66. data/vendor/faiss/faiss/IndexNNDescent.h +3 -3
  67. data/vendor/faiss/faiss/IndexNSG.cpp +2 -5
  68. data/vendor/faiss/faiss/IndexNSG.h +3 -3
  69. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  70. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  71. data/vendor/faiss/faiss/IndexPQ.cpp +26 -26
  72. data/vendor/faiss/faiss/IndexPQ.h +2 -2
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +2 -5
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +2 -11
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +2 -2
  76. data/vendor/faiss/faiss/IndexPreTransform.h +3 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +46 -9
  78. data/vendor/faiss/faiss/IndexRefine.h +9 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +2 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -2
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +2 -2
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +2 -2
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +5 -4
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -2
  85. data/vendor/faiss/faiss/IndexShards.cpp +2 -2
  86. data/vendor/faiss/faiss/IndexShards.h +2 -2
  87. data/vendor/faiss/faiss/IndexShardsIVF.cpp +2 -2
  88. data/vendor/faiss/faiss/IndexShardsIVF.h +2 -2
  89. data/vendor/faiss/faiss/MatrixStats.cpp +2 -2
  90. data/vendor/faiss/faiss/MatrixStats.h +2 -2
  91. data/vendor/faiss/faiss/MetaIndexes.cpp +2 -3
  92. data/vendor/faiss/faiss/MetaIndexes.h +2 -2
  93. data/vendor/faiss/faiss/MetricType.h +9 -4
  94. data/vendor/faiss/faiss/VectorTransform.cpp +2 -2
  95. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  96. data/vendor/faiss/faiss/clone_index.cpp +2 -2
  97. data/vendor/faiss/faiss/clone_index.h +2 -2
  98. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +2 -2
  99. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +2 -2
  100. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +97 -19
  101. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +192 -0
  102. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +29 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2 -2
  104. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +85 -32
  105. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2 -2
  106. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +2 -2
  107. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +2 -2
  108. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +2 -2
  109. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +2 -2
  110. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +2 -2
  111. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -5
  112. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +2 -2
  113. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +45 -13
  114. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  115. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +12 -6
  116. data/vendor/faiss/faiss/gpu/GpuDistance.h +11 -7
  117. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +3 -3
  118. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +2 -2
  119. data/vendor/faiss/faiss/gpu/GpuIndex.h +10 -15
  120. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -2
  121. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +285 -0
  122. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +2 -2
  123. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -2
  124. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +4 -2
  125. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +3 -3
  126. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +2 -2
  127. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +2 -2
  128. data/vendor/faiss/faiss/gpu/GpuResources.cpp +7 -2
  129. data/vendor/faiss/faiss/gpu/GpuResources.h +11 -4
  130. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +66 -11
  131. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +15 -5
  132. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -2
  133. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +28 -23
  134. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +2 -2
  135. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +2 -2
  136. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +2 -2
  137. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  138. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +2 -2
  139. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +8 -2
  140. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +2 -3
  141. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +2 -2
  142. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +10 -7
  143. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +2 -2
  144. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +54 -54
  145. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +144 -77
  146. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +51 -51
  147. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +2 -2
  148. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +3 -3
  149. data/vendor/faiss/faiss/gpu/test/TestGpuResidualQuantizer.cpp +70 -0
  150. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +74 -4
  151. data/vendor/faiss/faiss/gpu/test/TestUtils.h +2 -2
  152. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +3 -3
  153. data/vendor/faiss/faiss/gpu/utils/{RaftUtils.h → CuvsUtils.h} +12 -11
  154. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +8 -2
  155. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +2 -2
  156. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +2 -2
  157. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +2 -2
  158. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +6 -3
  159. data/vendor/faiss/faiss/gpu/utils/Timer.h +3 -3
  160. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +79 -11
  161. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +17 -5
  162. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +27 -2
  163. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +11 -3
  164. data/vendor/faiss/faiss/impl/CodePacker.cpp +2 -2
  165. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  166. data/vendor/faiss/faiss/impl/DistanceComputer.h +48 -2
  167. data/vendor/faiss/faiss/impl/FaissAssert.h +6 -4
  168. data/vendor/faiss/faiss/impl/FaissException.cpp +2 -2
  169. data/vendor/faiss/faiss/impl/FaissException.h +2 -3
  170. data/vendor/faiss/faiss/impl/HNSW.cpp +378 -205
  171. data/vendor/faiss/faiss/impl/HNSW.h +55 -24
  172. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  173. data/vendor/faiss/faiss/impl/IDSelector.h +2 -2
  174. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +10 -10
  175. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +2 -2
  176. data/vendor/faiss/faiss/impl/LookupTableScaler.h +36 -2
  177. data/vendor/faiss/faiss/impl/NNDescent.cpp +15 -10
  178. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  179. data/vendor/faiss/faiss/impl/NSG.cpp +26 -49
  180. data/vendor/faiss/faiss/impl/NSG.h +20 -8
  181. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +2 -2
  182. data/vendor/faiss/faiss/impl/PolysemousTraining.h +2 -2
  183. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +2 -4
  184. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +2 -2
  185. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -2
  186. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +3 -2
  187. data/vendor/faiss/faiss/impl/ProductQuantizer.h +7 -3
  188. data/vendor/faiss/faiss/impl/Quantizer.h +2 -2
  189. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +2 -36
  190. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +3 -13
  191. data/vendor/faiss/faiss/impl/ResultHandler.h +153 -34
  192. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +721 -104
  193. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +5 -2
  194. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +2 -2
  195. data/vendor/faiss/faiss/impl/ThreadedIndex.h +2 -2
  196. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +7 -2
  197. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  198. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +2 -2
  199. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +440 -0
  200. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +55 -2
  201. data/vendor/faiss/faiss/impl/index_read.cpp +31 -20
  202. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  203. data/vendor/faiss/faiss/impl/index_write.cpp +30 -16
  204. data/vendor/faiss/faiss/impl/io.cpp +15 -7
  205. data/vendor/faiss/faiss/impl/io.h +6 -6
  206. data/vendor/faiss/faiss/impl/io_macros.h +8 -9
  207. data/vendor/faiss/faiss/impl/kmeans1d.cpp +2 -3
  208. data/vendor/faiss/faiss/impl/kmeans1d.h +2 -2
  209. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +2 -3
  210. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  211. data/vendor/faiss/faiss/impl/platform_macros.h +34 -2
  212. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +13 -2
  213. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +20 -2
  214. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +3 -3
  215. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +450 -3
  216. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +8 -8
  217. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +3 -3
  218. data/vendor/faiss/faiss/impl/simd_result_handlers.h +151 -67
  219. data/vendor/faiss/faiss/index_factory.cpp +51 -34
  220. data/vendor/faiss/faiss/index_factory.h +2 -2
  221. data/vendor/faiss/faiss/index_io.h +14 -7
  222. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +30 -10
  223. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +5 -2
  224. data/vendor/faiss/faiss/invlists/DirectMap.cpp +11 -3
  225. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -2
  226. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +57 -19
  227. data/vendor/faiss/faiss/invlists/InvertedLists.h +20 -11
  228. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +2 -2
  229. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +2 -2
  230. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +23 -9
  231. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +4 -3
  232. data/vendor/faiss/faiss/python/python_callbacks.cpp +5 -5
  233. data/vendor/faiss/faiss/python/python_callbacks.h +2 -2
  234. data/vendor/faiss/faiss/utils/AlignedTable.h +5 -3
  235. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  236. data/vendor/faiss/faiss/utils/Heap.h +107 -2
  237. data/vendor/faiss/faiss/utils/NeuralNet.cpp +346 -0
  238. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  239. data/vendor/faiss/faiss/utils/WorkerThread.cpp +2 -2
  240. data/vendor/faiss/faiss/utils/WorkerThread.h +2 -2
  241. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  242. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  243. data/vendor/faiss/faiss/utils/approx_topk/generic.h +2 -2
  244. data/vendor/faiss/faiss/utils/approx_topk/mode.h +2 -2
  245. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +2 -2
  246. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  247. data/vendor/faiss/faiss/utils/distances.cpp +249 -90
  248. data/vendor/faiss/faiss/utils/distances.h +8 -8
  249. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +2 -2
  250. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  251. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  252. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +2 -2
  253. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +2 -2
  254. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +2 -2
  255. data/vendor/faiss/faiss/utils/distances_simd.cpp +1543 -56
  256. data/vendor/faiss/faiss/utils/extra_distances-inl.h +72 -2
  257. data/vendor/faiss/faiss/utils/extra_distances.cpp +87 -140
  258. data/vendor/faiss/faiss/utils/extra_distances.h +5 -4
  259. data/vendor/faiss/faiss/utils/fp16-arm.h +2 -2
  260. data/vendor/faiss/faiss/utils/fp16-fp16c.h +2 -2
  261. data/vendor/faiss/faiss/utils/fp16-inl.h +2 -2
  262. data/vendor/faiss/faiss/utils/fp16.h +2 -2
  263. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -2
  264. data/vendor/faiss/faiss/utils/hamming.cpp +3 -4
  265. data/vendor/faiss/faiss/utils/hamming.h +2 -2
  266. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +2 -2
  267. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +490 -0
  268. data/vendor/faiss/faiss/utils/hamming_distance/common.h +2 -2
  269. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +6 -3
  270. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +7 -3
  271. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +5 -5
  272. data/vendor/faiss/faiss/utils/ordered_key_value.h +2 -2
  273. data/vendor/faiss/faiss/utils/partitioning.cpp +2 -2
  274. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  275. data/vendor/faiss/faiss/utils/prefetch.h +2 -2
  276. data/vendor/faiss/faiss/utils/quantize_lut.cpp +2 -2
  277. data/vendor/faiss/faiss/utils/quantize_lut.h +2 -2
  278. data/vendor/faiss/faiss/utils/random.cpp +45 -2
  279. data/vendor/faiss/faiss/utils/random.h +27 -2
  280. data/vendor/faiss/faiss/utils/simdlib.h +12 -3
  281. data/vendor/faiss/faiss/utils/simdlib_avx2.h +2 -2
  282. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  283. data/vendor/faiss/faiss/utils/simdlib_emulated.h +2 -2
  284. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -4
  285. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  286. data/vendor/faiss/faiss/utils/sorting.cpp +2 -2
  287. data/vendor/faiss/faiss/utils/sorting.h +2 -2
  288. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +2 -2
  289. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  290. data/vendor/faiss/faiss/utils/utils.cpp +17 -10
  291. data/vendor/faiss/faiss/utils/utils.h +7 -3
  292. metadata +22 -11
  293. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -0,0 +1,490 @@
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #ifndef HAMMING_AVX512_INL_H
+ #define HAMMING_AVX512_INL_H
+
+ // AVX512 version
+ // The _mm512_popcnt_epi64 intrinsic is used to accelerate Hamming distance
+ // calculations in HammingComputerDefault and HammingComputer64. This intrinsic
+ // is not available in the default FAISS avx512 build mode but is only
+ // available in the avx512_spr build mode, which targets Intel(R) Sapphire
+ // Rapids.
+
+ #include <cassert>
+ #include <cstddef>
+ #include <cstdint>
+
+ #include <faiss/impl/platform_macros.h>
+
+ #include <immintrin.h>
+
+ namespace faiss {
+
+ /* Elementary Hamming distance computation: unoptimized */
+ template <size_t nbits, typename T>
+ inline T hamming(const uint8_t* bs1, const uint8_t* bs2) {
+     const size_t nbytes = nbits / 8;
+     size_t i;
+     T h = 0;
+     for (i = 0; i < nbytes; i++) {
+         h += (T)hamdis_tab_ham_bytes[bs1[i] ^ bs2[i]];
+     }
+     return h;
+ }
+
+ /* Hamming distances for multiples of 64 bits */
+ template <size_t nbits>
+ inline hamdis_t hamming(const uint64_t* bs1, const uint64_t* bs2) {
+     const size_t nwords = nbits / 64;
+     size_t i;
+     hamdis_t h = 0;
+     for (i = 0; i < nwords; i++) {
+         h += popcount64(bs1[i] ^ bs2[i]);
+     }
+     return h;
+ }
+
+ /* specialized (optimized) functions */
+ template <>
+ inline hamdis_t hamming<64>(const uint64_t* pa, const uint64_t* pb) {
+     return popcount64(pa[0] ^ pb[0]);
+ }
+
+ template <>
+ inline hamdis_t hamming<128>(const uint64_t* pa, const uint64_t* pb) {
+     return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]);
+ }
+
+ template <>
+ inline hamdis_t hamming<256>(const uint64_t* pa, const uint64_t* pb) {
+     return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]) +
+             popcount64(pa[2] ^ pb[2]) + popcount64(pa[3] ^ pb[3]);
+ }
+
+ /* Hamming distances for multiple of 64 bits */
+ inline hamdis_t hamming(
+         const uint64_t* bs1,
+         const uint64_t* bs2,
+         size_t nwords) {
+     hamdis_t h = 0;
+     for (size_t i = 0; i < nwords; i++) {
+         h += popcount64(bs1[i] ^ bs2[i]);
+     }
+     return h;
+ }
+
+ /******************************************************************
+  * The HammingComputer series of classes compares a single code of
+  * size 4 to 32 to incoming codes. They are intended for use as a
+  * template class where it would be inefficient to switch on the code
+  * size in the inner loop. Hopefully the compiler will inline the
+  * hamming() functions and put the a0, a1, ... in registers.
+  ******************************************************************/
+
+ struct HammingComputer4 {
+     uint32_t a0;
+
+     HammingComputer4() {}
+
+     HammingComputer4(const uint8_t* a, int code_size) {
+         set(a, code_size);
+     }
+
+     void set(const uint8_t* a, int code_size) {
+         assert(code_size == 4);
+         a0 = *(uint32_t*)a;
+     }
+
+     inline int hamming(const uint8_t* b) const {
+         return popcount64(*(uint32_t*)b ^ a0);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 4;
+     }
+ };
+
+ struct HammingComputer8 {
+     uint64_t a0;
+
+     HammingComputer8() {}
+
+     HammingComputer8(const uint8_t* a, int code_size) {
+         set(a, code_size);
+     }
+
+     void set(const uint8_t* a, int code_size) {
+         assert(code_size == 8);
+         a0 = *(uint64_t*)a;
+     }
+
+     inline int hamming(const uint8_t* b) const {
+         return popcount64(*(uint64_t*)b ^ a0);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 8;
+     }
+ };
+
+ struct HammingComputer16 {
+     uint64_t a0, a1;
+
+     HammingComputer16() {}
+
+     HammingComputer16(const uint8_t* a8, int code_size) {
+         set(a8, code_size);
+     }
+
+     void set(const uint8_t* a8, int code_size) {
+         assert(code_size == 16);
+         const uint64_t* a = (uint64_t*)a8;
+         a0 = a[0];
+         a1 = a[1];
+     }
+
+     inline int hamming(const uint8_t* b8) const {
+         const uint64_t* b = (uint64_t*)b8;
+         return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 16;
+     }
+ };
+
+ // when applied to an array, 1/2 of the 64-bit accesses are unaligned.
+ // This incurs a penalty of ~10% wrt. fully aligned accesses.
+ struct HammingComputer20 {
+     uint64_t a0, a1;
+     uint32_t a2;
+
+     HammingComputer20() {}
+
+     HammingComputer20(const uint8_t* a8, int code_size) {
+         set(a8, code_size);
+     }
+
+     void set(const uint8_t* a8, int code_size) {
+         assert(code_size == 20);
+         const uint64_t* a = (uint64_t*)a8;
+         a0 = a[0];
+         a1 = a[1];
+         a2 = a[2];
+     }
+
+     inline int hamming(const uint8_t* b8) const {
+         const uint64_t* b = (uint64_t*)b8;
+         return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1) +
+                 popcount64(*(uint32_t*)(b + 2) ^ a2);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 20;
+     }
+ };
+
+ struct HammingComputer32 {
+     uint64_t a0, a1, a2, a3;
+
+     HammingComputer32() {}
+
+     HammingComputer32(const uint8_t* a8, int code_size) {
+         set(a8, code_size);
+     }
+
+     void set(const uint8_t* a8, int code_size) {
+         assert(code_size == 32);
+         const uint64_t* a = (uint64_t*)a8;
+         a0 = a[0];
+         a1 = a[1];
+         a2 = a[2];
+         a3 = a[3];
+     }
+
+     inline int hamming(const uint8_t* b8) const {
+         const uint64_t* b = (uint64_t*)b8;
+         return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1) +
+                 popcount64(b[2] ^ a2) + popcount64(b[3] ^ a3);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 32;
+     }
+ };
+
+ struct HammingComputer64 {
+     uint64_t a0, a1, a2, a3, a4, a5, a6, a7;
+     const uint64_t* a;
+
+     HammingComputer64() {}
+
+     HammingComputer64(const uint8_t* a8, int code_size) {
+         set(a8, code_size);
+     }
+
+     void set(const uint8_t* a8, int code_size) {
+         assert(code_size == 64);
+         a = (uint64_t*)a8;
+         a0 = a[0];
+         a1 = a[1];
+         a2 = a[2];
+         a3 = a[3];
+         a4 = a[4];
+         a5 = a[5];
+         a6 = a[6];
+         a7 = a[7];
+     }
+
+     inline int hamming(const uint8_t* b8) const {
+         const uint64_t* b = (uint64_t*)b8;
+ #ifdef __AVX512VPOPCNTDQ__
+         __m512i vxor =
+                 _mm512_xor_si512(_mm512_loadu_si512(a), _mm512_loadu_si512(b));
+         __m512i vpcnt = _mm512_popcnt_epi64(vxor);
+         // reduce performs better than adding the lower and higher parts
+         return _mm512_reduce_add_epi32(vpcnt);
+ #else
+         return popcount64(b[0] ^ a0) + popcount64(b[1] ^ a1) +
+                 popcount64(b[2] ^ a2) + popcount64(b[3] ^ a3) +
+                 popcount64(b[4] ^ a4) + popcount64(b[5] ^ a5) +
+                 popcount64(b[6] ^ a6) + popcount64(b[7] ^ a7);
+ #endif
+     }
+
+     inline static constexpr int get_code_size() {
+         return 64;
+     }
+ };
+
+ struct HammingComputerDefault {
+     const uint8_t* a8;
+     int quotient8;
+     int remainder8;
+
+     HammingComputerDefault() {}
+
+     HammingComputerDefault(const uint8_t* a8, int code_size) {
+         set(a8, code_size);
+     }
+
+     void set(const uint8_t* a8_2, int code_size) {
+         this->a8 = a8_2;
+         quotient8 = code_size / 8;
+         remainder8 = code_size % 8;
+     }
+
+     int hamming(const uint8_t* b8) const {
+         int accu = 0;
+
+         const uint64_t* a64 = reinterpret_cast<const uint64_t*>(a8);
+         const uint64_t* b64 = reinterpret_cast<const uint64_t*>(b8);
+
+         int i = 0;
+ #ifdef __AVX512VPOPCNTDQ__
+         int quotient64 = quotient8 / 8;
+         for (; i < quotient64; ++i) {
+             __m512i vxor = _mm512_xor_si512(
+                     _mm512_loadu_si512(&a64[i * 8]),
+                     _mm512_loadu_si512(&b64[i * 8]));
+             __m512i vpcnt = _mm512_popcnt_epi64(vxor);
+             // reduce performs better than adding the lower and higher parts
+             accu += _mm512_reduce_add_epi32(vpcnt);
+         }
+         i *= 8;
+ #endif
+         int len = quotient8 - i;
+         switch (len & 7) {
+             default:
+                 while (len > 7) {
+                     len -= 8;
+                     accu += popcount64(a64[i] ^ b64[i]);
+                     i++;
+                     [[fallthrough]];
+                     case 7:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                         [[fallthrough]];
+                     case 6:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                         [[fallthrough]];
+                     case 5:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                         [[fallthrough]];
+                     case 4:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                         [[fallthrough]];
+                     case 3:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                         [[fallthrough]];
+                     case 2:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                         [[fallthrough]];
+                     case 1:
+                         accu += popcount64(a64[i] ^ b64[i]);
+                         i++;
+                 }
+         }
+         if (remainder8) {
+             const uint8_t* a = a8 + 8 * quotient8;
+             const uint8_t* b = b8 + 8 * quotient8;
+             switch (remainder8) {
+                 case 7:
+                     accu += hamdis_tab_ham_bytes[a[6] ^ b[6]];
+                     [[fallthrough]];
+                 case 6:
+                     accu += hamdis_tab_ham_bytes[a[5] ^ b[5]];
+                     [[fallthrough]];
+                 case 5:
+                     accu += hamdis_tab_ham_bytes[a[4] ^ b[4]];
+                     [[fallthrough]];
+                 case 4:
+                     accu += hamdis_tab_ham_bytes[a[3] ^ b[3]];
+                     [[fallthrough]];
+                 case 3:
+                     accu += hamdis_tab_ham_bytes[a[2] ^ b[2]];
+                     [[fallthrough]];
+                 case 2:
+                     accu += hamdis_tab_ham_bytes[a[1] ^ b[1]];
+                     [[fallthrough]];
+                 case 1:
+                     accu += hamdis_tab_ham_bytes[a[0] ^ b[0]];
+                     [[fallthrough]];
+                 default:
+                     break;
+             }
+         }
+
+         return accu;
+     }
+
+     inline int get_code_size() const {
+         return quotient8 * 8 + remainder8;
+     }
+ };
+
+ /***************************************************************************
+  * generalized Hamming = number of bytes that are different between
+  * two codes.
+  ***************************************************************************/
+
+ inline int generalized_hamming_64(uint64_t a) {
+     a |= a >> 1;
+     a |= a >> 2;
+     a |= a >> 4;
+     a &= 0x0101010101010101UL;
+     return popcount64(a);
+ }
+
+ struct GenHammingComputer8 {
+     uint64_t a0;
+
+     GenHammingComputer8(const uint8_t* a, int code_size) {
+         assert(code_size == 8);
+         a0 = *(uint64_t*)a;
+     }
+
+     inline int hamming(const uint8_t* b) const {
+         return generalized_hamming_64(*(uint64_t*)b ^ a0);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 8;
+     }
+ };
+
+ // I'm not sure whether this version is faster of slower, tbh
+ // todo: test on different CPUs
+ struct GenHammingComputer16 {
+     __m128i a;
+
+     GenHammingComputer16(const uint8_t* a8, int code_size) {
+         assert(code_size == 16);
+         a = _mm_loadu_si128((const __m128i_u*)a8);
+     }
+
+     inline int hamming(const uint8_t* b8) const {
+         const __m128i b = _mm_loadu_si128((const __m128i_u*)b8);
+         const __m128i cmp = _mm_cmpeq_epi8(a, b);
+         const auto movemask = _mm_movemask_epi8(cmp);
+         return 16 - popcount32(movemask);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 16;
+     }
+ };
+
+ struct GenHammingComputer32 {
+     __m256i a;
+
+     GenHammingComputer32(const uint8_t* a8, int code_size) {
+         assert(code_size == 32);
+         a = _mm256_loadu_si256((const __m256i_u*)a8);
+     }
+
+     inline int hamming(const uint8_t* b8) const {
+         const __m256i b = _mm256_loadu_si256((const __m256i_u*)b8);
+         const __m256i cmp = _mm256_cmpeq_epi8(a, b);
+         const uint32_t movemask = _mm256_movemask_epi8(cmp);
+         return 32 - popcount32(movemask);
+     }
+
+     inline static constexpr int get_code_size() {
+         return 32;
+     }
+ };
+
+ // A specialized version might be needed for the very long
+ // GenHamming code_size. In such a case, one may accumulate
+ // counts using _mm256_sub_epi8 and then compute a horizontal
+ // sum (using _mm256_sad_epu8, maybe, in blocks of no larger
+ // than 256 * 32 bytes).
+
+ struct GenHammingComputerM8 {
+     const uint64_t* a;
+     int n;
+
+     GenHammingComputerM8(const uint8_t* a8, int code_size) {
+         assert(code_size % 8 == 0);
+         a = (uint64_t*)a8;
+         n = code_size / 8;
+     }
+
+     int hamming(const uint8_t* b8) const {
+         const uint64_t* b = (uint64_t*)b8;
+         int accu = 0;
+
+         int i = 0;
+         int n4 = (n / 4) * 4;
+         for (; i < n4; i += 4) {
+             const __m256i av = _mm256_loadu_si256((const __m256i_u*)(a + i));
+             const __m256i bv = _mm256_loadu_si256((const __m256i_u*)(b + i));
+             const __m256i cmp = _mm256_cmpeq_epi8(av, bv);
+             const uint32_t movemask = _mm256_movemask_epi8(cmp);
+             accu += 32 - popcount32(movemask);
+         }
+
+         for (; i < n; i++)
+             accu += generalized_hamming_64(a[i] ^ b[i]);
+         return accu;
+     }
+
+     inline int get_code_size() const {
+         return n * 8;
+     }
+ };
+
+ } // namespace faiss
+
+ #endif
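
For context (not part of the released diff): the HammingComputer structs above are fixed-code-size comparators. One code is loaded once via set(), ideally staying in registers, and hamming() is then called against many incoming codes. A minimal usage sketch, assuming <faiss/utils/hamming.h> pulls these structs in; the names nearest_code, query and codes are illustrative:

    #include <cstddef>
    #include <cstdint>
    #include <faiss/utils/hamming.h>

    // Returns the index of the 32-byte code closest in Hamming distance
    // to `query` among `ncodes` contiguous codes.
    int nearest_code(const uint8_t* query, const uint8_t* codes, size_t ncodes) {
        faiss::HammingComputer32 hc(query, 32); // set() asserts code_size == 32
        int best_dis = 1 << 30;
        int best_idx = -1;
        for (size_t i = 0; i < ncodes; i++) {
            int dis = hc.hamming(codes + i * 32);
            if (dis < best_dis) {
                best_dis = dis;
                best_idx = static_cast<int>(i);
            }
        }
        return best_idx;
    }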
@@ -1,5 +1,5 @@
- /**
-  * Copyright (c) Facebook, Inc. and its affiliates.
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
   *
   * This source code is licensed under the MIT license found in the
   * LICENSE file in the root directory of this source tree.
@@ -166,9 +166,12 @@ struct HammingComputer20 {
      void set(const uint8_t* a8, int code_size) {
          assert(code_size == 20);
          const uint64_t* a = (uint64_t*)a8;
+         const uint32_t* b = (uint32_t*)a8;
          a0 = a[0];
          a1 = a[1];
-         a2 = a[2];
+         // can't read a[2] since it is uint64_t, not uint32_t
+         // results in AddressSanitizer failure reading past end of array
+         a2 = b[4];
      }
  
      inline int hamming(const uint8_t* b8) const {
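
For context (not part of the released diff): the hunk above fixes an out-of-bounds read. A 20-byte code occupies bytes 0..19, so reading a[2] as a uint64_t touches bytes 16..23, four bytes past the buffer, which AddressSanitizer flags; b[4] as a uint32_t reads exactly bytes 16..19. A standalone illustration of the arithmetic (the function name is hypothetical):

    #include <cstdint>

    // Bytes covered by each access into a 20-byte code:
    //   as uint64_t: a[0] -> 0..7, a[1] -> 8..15, a[2] -> 16..23 (4 bytes OOB)
    //   as uint32_t: b[4] -> 16..19 (exactly the last 4 in-bounds bytes)
    uint32_t last_word_of_20_byte_code(const uint8_t* code20) {
        const uint32_t* b = reinterpret_cast<const uint32_t*>(code20);
        return b[4]; // mirrors the corrected set() above
    }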
@@ -16,6 +16,9 @@
  #ifdef __aarch64__
  // ARM compilers may produce inoptimal code for Hamming distance somewhy.
  #include <faiss/utils/hamming_distance/neon-inl.h>
+ #elif __AVX512F__
+ // offers better performance where __AVX512VPOPCNTDQ__ is supported
+ #include <faiss/utils/hamming_distance/avx512-inl.h>
  #elif __AVX2__
  // better versions for GenHammingComputer
  #include <faiss/utils/hamming_distance/avx2-inl.h>
@@ -55,7 +58,7 @@ SPECIALIZED_HC(64);
  /***************************************************************************
   * Dispatching function that takes a code size and a consumer object
   * the consumer object should contain a retun type t and a operation template
-  * function f() that to be called to perform the operation.
+  * function f() that must be called to perform the operation.
   **************************************************************************/
  
  template <class Consumer, class... Types>
@@ -76,6 +79,7 @@ typename Consumer::T dispatch_HammingComputer(
          default:
              return consumer.template f<HammingComputerDefault>(args...);
      }
+ #undef DISPATCH_HC
  }
  
  } // namespace faiss
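
For context (not part of the released diff): as this hunk shows, dispatch_HammingComputer maps a runtime code size to a concrete HammingComputer type and calls consumer.template f<HC>(args...). A sketch of a conforming consumer under that assumption; CountBelowThreshold and its fields are illustrative, not from the diff:

    #include <cstddef>
    #include <cstdint>

    // T is the dispatcher's return type; f<HC>() runs the operation with the
    // statically selected HammingComputer type HC.
    struct CountBelowThreshold {
        using T = size_t;

        const uint8_t* query;
        const uint8_t* codes;
        size_t ncodes;
        int code_size;

        template <class HammingComputer>
        T f(int threshold) {
            HammingComputer hc(query, code_size);
            size_t count = 0;
            for (size_t i = 0; i < ncodes; i++) {
                if (hc.hamming(codes + i * code_size) < threshold) {
                    count++;
                }
            }
            return count;
        }
    };

    // Assumed call pattern, matching the signature visible in the hunk header:
    //   size_t n = faiss::dispatch_HammingComputer(code_size, consumer, threshold);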
@@ -98,9 +98,9 @@ inline hamdis_t hamming<256>(const uint64_t* pa, const uint64_t* pb) {
  
  /* Hamming distances for multiple of 64 bits */
  inline hamdis_t hamming(const uint64_t* pa, const uint64_t* pb, size_t nwords) {
-     const size_t nwords256 = nwords / 256;
-     const size_t nwords128 = (nwords - nwords256 * 256) / 128;
-     const size_t nwords64 = (nwords - nwords256 * 256 - nwords128 * 128) / 64;
+     const size_t nwords256 = nwords / 4;
+     const size_t nwords128 = (nwords % 4) / 2;
+     const size_t nwords64 = nwords % 2;
  
      hamdis_t h = 0;
      if (nwords256 > 0) {
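
For context (not part of the released diff): nwords here counts 64-bit words, so a 256-bit block consumes 4 words and a 128-bit block consumes 2. The old divisors (256, 128, 64) treated nwords as a bit count and produced all-zero block counts for any nwords below 64. A compile-time check of the corrected arithmetic, as a sketch:

    #include <cstddef>

    constexpr size_t nwords = 7; // e.g. a 448-bit code = 7 x 64-bit words
    constexpr size_t nwords256 = nwords / 4;       // one 4-word (256-bit) block
    constexpr size_t nwords128 = (nwords % 4) / 2; // one 2-word (128-bit) block
    constexpr size_t nwords64 = nwords % 2;        // one trailing 64-bit word
    static_assert(
            nwords256 * 4 + nwords128 * 2 + nwords64 == nwords,
            "every word is covered exactly once");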
@@ -54,6 +54,37 @@ double RandomGenerator::rand_double() {
      return mt() / double(mt.max());
  }
  
+ SplitMix64RandomGenerator::SplitMix64RandomGenerator(int64_t seed)
+         : state{static_cast<uint64_t>(seed)} {}
+ 
+ int SplitMix64RandomGenerator::rand_int() {
+     return next() & 0x7fffffff;
+ }
+ 
+ int64_t SplitMix64RandomGenerator::rand_int64() {
+     uint64_t value = next();
+     return static_cast<int64_t>(value & 0x7fffffffffffffffULL);
+ }
+ 
+ int SplitMix64RandomGenerator::rand_int(int max) {
+     return next() % max;
+ }
+ 
+ float SplitMix64RandomGenerator::rand_float() {
+     return next() / float(std::numeric_limits<uint64_t>::max());
+ }
+ 
+ double SplitMix64RandomGenerator::rand_double() {
+     return next() / double(std::numeric_limits<uint64_t>::max());
+ }
+ 
+ uint64_t SplitMix64RandomGenerator::next() {
+     uint64_t z = (state += 0x9e3779b97f4a7c15ULL);
+     z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
+     z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
+     return z ^ (z >> 31);
+ }
+ 
  /***********************************************************************
   * Random functions in this C file only exist because Torch
   * counterparts are slow and not multi-threaded. Typical use is for
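
For context (not part of the released diff): SplitMix64 advances a single uint64_t state by the 64-bit golden-ratio constant 0x9e3779b97f4a7c15 and scrambles it with xor-shift/multiply steps, so equal seeds yield identical streams. A usage sketch, assuming the class is declared in <faiss/utils/random.h> as the +27 -2 change to that header suggests:

    #include <faiss/utils/random.h>
    #include <cassert>

    void splitmix64_demo() {
        faiss::SplitMix64RandomGenerator a(42), b(42);
        for (int i = 0; i < 4; i++) {
            assert(a.rand_int64() == b.rand_int64()); // deterministic streams
        }
        int r = a.rand_int(100); // maps next() into [0, 100) by plain modulo
        assert(0 <= r && r < 100);
    }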
@@ -162,6 +193,18 @@ void rand_perm(int* perm, size_t n, int64_t seed) {
      }
  }
  
+ void rand_perm_splitmix64(int* perm, size_t n, int64_t seed) {
+     for (size_t i = 0; i < n; i++)
+         perm[i] = i;
+ 
+     SplitMix64RandomGenerator rng(seed);
+ 
+     for (size_t i = 0; i + 1 < n; i++) {
+         int i2 = i + rng.rand_int(n - i);
+         std::swap(perm[i], perm[i2]);
+     }
+ }
+ 
  void byte_rand(uint8_t* x, size_t n, int64_t seed) {
      // only try to parallelize on large enough arrays
      const size_t nblock = n < 1024 ? 1 : 1024;
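
For context (not part of the released diff): rand_perm_splitmix64 is a Fisher-Yates shuffle of the identity permutation, driven by the SplitMix64 generator added above. A usage sketch (make_perm is illustrative, not part of the diff):

    #include <faiss/utils/random.h>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<int> make_perm(size_t n, int64_t seed) {
        std::vector<int> perm(n);
        // Fills perm with a random permutation of 0..n-1; the same seed
        // always produces the same permutation.
        faiss::rand_perm_splitmix64(perm.data(), n, seed);
        return perm;
    }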