faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,19 +9,18 @@
9
9
 
10
10
  #include <faiss/impl/lattice_Zn.h>
11
11
 
12
- #include <cstdlib>
12
+ #include <cassert>
13
13
  #include <cmath>
14
+ #include <cstdlib>
14
15
  #include <cstring>
15
- #include <cassert>
16
16
 
17
+ #include <algorithm>
17
18
  #include <queue>
18
- #include <unordered_set>
19
19
  #include <unordered_map>
20
- #include <algorithm>
20
+ #include <unordered_set>
21
21
 
22
- #include <faiss/utils/distances.h>
23
22
  #include <faiss/impl/platform_macros.h>
24
-
23
+ #include <faiss/utils/distances.h>
25
24
 
26
25
  namespace faiss {
27
26
 
@@ -35,44 +34,41 @@ inline float sqr(float x) {
35
34
  return x * x;
36
35
  }
37
36
 
38
-
39
37
  typedef std::vector<float> point_list_t;
40
38
 
41
39
  struct Comb {
42
40
  std::vector<uint64_t> tab; // Pascal's triangle
43
41
  int nmax;
44
42
 
45
- explicit Comb(int nmax): nmax(nmax) {
43
+ explicit Comb(int nmax) : nmax(nmax) {
46
44
  tab.resize(nmax * nmax, 0);
47
45
  tab[0] = 1;
48
- for(int i = 1; i < nmax; i++) {
46
+ for (int i = 1; i < nmax; i++) {
49
47
  tab[i * nmax] = 1;
50
- for(int j = 1; j <= i; j++) {
48
+ for (int j = 1; j <= i; j++) {
51
49
  tab[i * nmax + j] =
52
- tab[(i - 1) * nmax + j] +
53
- tab[(i - 1) * nmax + (j - 1)];
50
+ tab[(i - 1) * nmax + j] + tab[(i - 1) * nmax + (j - 1)];
54
51
  }
55
-
56
52
  }
57
53
  }
58
54
 
59
55
  uint64_t operator()(int n, int p) const {
60
- assert (n < nmax && p < nmax);
61
- if (p > n) return 0;
56
+ assert(n < nmax && p < nmax);
57
+ if (p > n)
58
+ return 0;
62
59
  return tab[n * nmax + p];
63
60
  }
64
61
  };
65
62
 
66
63
  Comb comb(100);
67
64
 
68
-
69
-
70
65
  // compute combinations of n integer values <= v that sum up to total (squared)
71
- point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
66
+ point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
72
67
  if (total < 0) {
73
68
  return point_list_t();
74
69
  } else if (n == 1) {
75
- while (sqr(v + add) > total) v--;
70
+ while (sqr(v + add) > total)
71
+ v--;
76
72
  if (sqr(v + add) == total) {
77
73
  return point_list_t(1, v + add);
78
74
  } else {
@@ -82,9 +78,9 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
82
78
  point_list_t res;
83
79
  while (v >= 0) {
84
80
  point_list_t sub_points =
85
- sum_of_sq (total - sqr(v + add), v, n - 1, add);
81
+ sum_of_sq(total - sqr(v + add), v, n - 1, add);
86
82
  for (size_t i = 0; i < sub_points.size(); i += n - 1) {
87
- res.push_back (v + add);
83
+ res.push_back(v + add);
88
84
  for (int j = 0; j < n - 1; j++) {
89
85
  res.push_back(sub_points[i + j]);
90
86
  }
@@ -95,7 +91,7 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
95
91
  }
96
92
  }
97
93
 
98
- int decode_comb_1 (uint64_t *n, int k1, int r) {
94
+ int decode_comb_1(uint64_t* n, int k1, int r) {
99
95
  while (comb(r, k1) > *n) {
100
96
  r--;
101
97
  }
@@ -104,10 +100,10 @@ int decode_comb_1 (uint64_t *n, int k1, int r) {
104
100
  }
105
101
 
106
102
  // optimized version for < 64 bits
107
- uint64_t repeats_encode_64 (
108
- const std::vector<Repeat> & repeats,
109
- int dim, const float *c)
110
- {
103
+ uint64_t repeats_encode_64(
104
+ const std::vector<Repeat>& repeats,
105
+ int dim,
106
+ const float* c) {
111
107
  uint64_t coded = 0;
112
108
  int nfree = dim;
113
109
  uint64_t code = 0, shift = 1;
@@ -115,15 +111,16 @@ uint64_t repeats_encode_64 (
115
111
  int rank = 0, occ = 0;
116
112
  uint64_t code_comb = 0;
117
113
  uint64_t tosee = ~coded;
118
- for(;;) {
114
+ for (;;) {
119
115
  // directly jump to next available slot.
120
116
  int i = __builtin_ctzll(tosee);
121
- tosee &= ~(uint64_t{1} << i) ;
117
+ tosee &= ~(uint64_t{1} << i);
122
118
  if (c[i] == r->val) {
123
119
  code_comb += comb(rank, occ + 1);
124
120
  occ++;
125
121
  coded |= uint64_t{1} << i;
126
- if (occ == r->n) break;
122
+ if (occ == r->n)
123
+ break;
127
124
  }
128
125
  rank++;
129
126
  }
@@ -135,11 +132,11 @@ uint64_t repeats_encode_64 (
135
132
  return code;
136
133
  }
137
134
 
138
-
139
135
  void repeats_decode_64(
140
- const std::vector<Repeat> & repeats,
141
- int dim, uint64_t code, float *c)
142
- {
136
+ const std::vector<Repeat>& repeats,
137
+ int dim,
138
+ uint64_t code,
139
+ float* c) {
143
140
  uint64_t decoded = 0;
144
141
  int nfree = dim;
145
142
  for (auto r = repeats.begin(); r != repeats.end(); ++r) {
@@ -149,9 +146,9 @@ void repeats_decode_64(
149
146
 
150
147
  int occ = 0;
151
148
  int rank = nfree;
152
- int next_rank = decode_comb_1 (&code_comb, r->n, rank);
149
+ int next_rank = decode_comb_1(&code_comb, r->n, rank);
153
150
  uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
154
- for(;;) {
151
+ for (;;) {
155
152
  int i = 63 - __builtin_clzll(tosee);
156
153
  tosee &= ~(uint64_t{1} << i);
157
154
  rank--;
@@ -159,25 +156,21 @@ void repeats_decode_64(
159
156
  decoded |= uint64_t{1} << i;
160
157
  c[i] = r->val;
161
158
  occ++;
162
- if (occ == r->n) break;
163
- next_rank = decode_comb_1 (
164
- &code_comb, r->n - occ, next_rank);
159
+ if (occ == r->n)
160
+ break;
161
+ next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
165
162
  }
166
163
  }
167
164
  nfree -= r->n;
168
165
  }
169
-
170
166
  }
171
167
 
172
-
173
-
174
168
  } // anonymous namespace
175
169
 
176
- Repeats::Repeats (int dim, const float *c): dim(dim)
177
- {
178
- for(int i = 0; i < dim; i++) {
170
+ Repeats::Repeats(int dim, const float* c) : dim(dim) {
171
+ for (int i = 0; i < dim; i++) {
179
172
  int j = 0;
180
- for(;;) {
173
+ for (;;) {
181
174
  if (j == repeats.size()) {
182
175
  repeats.push_back(Repeat{c[i], 1});
183
176
  break;
@@ -191,9 +184,7 @@ Repeats::Repeats (int dim, const float *c): dim(dim)
191
184
  }
192
185
  }
193
186
 
194
-
195
- uint64_t Repeats::count () const
196
- {
187
+ uint64_t Repeats::count() const {
197
188
  uint64_t accu = 1;
198
189
  int remain = dim;
199
190
  for (int i = 0; i < repeats.size(); i++) {
@@ -203,13 +194,10 @@ uint64_t Repeats::count () const
203
194
  return accu;
204
195
  }
205
196
 
206
-
207
-
208
197
  // version with a bool vector that works for > 64 dim
209
- uint64_t Repeats::encode(const float *c) const
210
- {
198
+ uint64_t Repeats::encode(const float* c) const {
211
199
  if (dim < 64) {
212
- return repeats_encode_64 (repeats, dim, c);
200
+ return repeats_encode_64(repeats, dim, c);
213
201
  }
214
202
  std::vector<bool> coded(dim, false);
215
203
  int nfree = dim;
@@ -223,7 +211,8 @@ uint64_t Repeats::encode(const float *c) const
223
211
  code_comb += comb(rank, occ + 1);
224
212
  occ++;
225
213
  coded[i] = true;
226
- if (occ == r->n) break;
214
+ if (occ == r->n)
215
+ break;
227
216
  }
228
217
  rank++;
229
218
  }
@@ -236,12 +225,9 @@ uint64_t Repeats::encode(const float *c) const
236
225
  return code;
237
226
  }
238
227
 
239
-
240
-
241
- void Repeats::decode(uint64_t code, float *c) const
242
- {
228
+ void Repeats::decode(uint64_t code, float* c) const {
243
229
  if (dim < 64) {
244
- repeats_decode_64 (repeats, dim, code, c);
230
+ repeats_decode_64(repeats, dim, code, c);
245
231
  return;
246
232
  }
247
233
 
@@ -254,7 +240,7 @@ void Repeats::decode(uint64_t code, float *c) const
254
240
 
255
241
  int occ = 0;
256
242
  int rank = nfree;
257
- int next_rank = decode_comb_1 (&code_comb, r->n, rank);
243
+ int next_rank = decode_comb_1(&code_comb, r->n, rank);
258
244
  for (int i = dim - 1; i >= 0; i--) {
259
245
  if (!decoded[i]) {
260
246
  rank--;
@@ -262,65 +248,61 @@ void Repeats::decode(uint64_t code, float *c) const
262
248
  decoded[i] = true;
263
249
  c[i] = r->val;
264
250
  occ++;
265
- if (occ == r->n) break;
266
- next_rank = decode_comb_1 (
267
- &code_comb, r->n - occ, next_rank);
251
+ if (occ == r->n)
252
+ break;
253
+ next_rank =
254
+ decode_comb_1(&code_comb, r->n - occ, next_rank);
268
255
  }
269
256
  }
270
257
  }
271
258
  nfree -= r->n;
272
259
  }
273
-
274
260
  }
275
261
 
276
-
277
-
278
262
  /********************************************
279
263
  * EnumeratedVectors functions
280
264
  ********************************************/
281
265
 
282
-
283
- void EnumeratedVectors::encode_multi(size_t n, const float *c,
284
- uint64_t * codes) const
285
- {
266
+ void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
267
+ const {
286
268
  #pragma omp parallel if (n > 1000)
287
269
  {
288
270
  #pragma omp for
289
- for(int i = 0; i < n; i++) {
271
+ for (int i = 0; i < n; i++) {
290
272
  codes[i] = encode(c + i * dim);
291
273
  }
292
274
  }
293
275
  }
294
276
 
295
-
296
- void EnumeratedVectors::decode_multi(size_t n, const uint64_t * codes,
297
- float *c) const
298
- {
277
+ void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
278
+ const {
299
279
  #pragma omp parallel if (n > 1000)
300
280
  {
301
281
  #pragma omp for
302
- for(int i = 0; i < n; i++) {
282
+ for (int i = 0; i < n; i++) {
303
283
  decode(codes[i], c + i * dim);
304
284
  }
305
285
  }
306
286
  }
307
287
 
308
- void EnumeratedVectors::find_nn (
309
- size_t nc, const uint64_t * codes,
310
- size_t nq, const float *xq,
311
- int64_t *labels, float *distances)
312
- {
288
+ void EnumeratedVectors::find_nn(
289
+ size_t nc,
290
+ const uint64_t* codes,
291
+ size_t nq,
292
+ const float* xq,
293
+ int64_t* labels,
294
+ float* distances) {
313
295
  for (size_t i = 0; i < nq; i++) {
314
296
  distances[i] = -1e20;
315
297
  labels[i] = -1;
316
298
  }
317
299
 
318
300
  std::vector<float> c(dim);
319
- for(size_t i = 0; i < nc; i++) {
301
+ for (size_t i = 0; i < nc; i++) {
320
302
  uint64_t code = codes[nc];
321
303
  decode(code, c.data());
322
304
  for (size_t j = 0; j < nq; j++) {
323
- const float *x = xq + j * dim;
305
+ const float* x = xq + j * dim;
324
306
  float dis = fvec_inner_product(x, c.data(), dim);
325
307
  if (dis > distances[j]) {
326
308
  distances[j] = dis;
@@ -328,45 +310,41 @@ void EnumeratedVectors::find_nn (
328
310
  }
329
311
  }
330
312
  }
331
-
332
313
  }
333
314
 
334
-
335
315
  /**********************************************************
336
316
  * ZnSphereSearch
337
317
  **********************************************************/
338
318
 
339
-
340
- ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) {
319
+ ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
341
320
  voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
342
321
  natom = voc.size() / dim;
343
322
  }
344
323
 
345
- float ZnSphereSearch::search(const float *x, float *c) const {
324
+ float ZnSphereSearch::search(const float* x, float* c) const {
346
325
  std::vector<float> tmp(dimS * 2);
347
326
  std::vector<int> tmp_int(dimS);
348
327
  return search(x, c, tmp.data(), tmp_int.data());
349
328
  }
350
329
 
351
- float ZnSphereSearch::search(const float *x, float *c,
352
- float *tmp, // size 2 *dim
353
- int *tmp_int, // size dim
354
- int *ibest_out
355
- ) const {
330
+ float ZnSphereSearch::search(
331
+ const float* x,
332
+ float* c,
333
+ float* tmp, // size 2 *dim
334
+ int* tmp_int, // size dim
335
+ int* ibest_out) const {
356
336
  int dim = dimS;
357
- assert (natom > 0);
358
- int *o = tmp_int;
359
- float *xabs = tmp;
360
- float *xperm = tmp + dim;
337
+ assert(natom > 0);
338
+ int* o = tmp_int;
339
+ float* xabs = tmp;
340
+ float* xperm = tmp + dim;
361
341
 
362
342
  // argsort
363
343
  for (int i = 0; i < dim; i++) {
364
344
  o[i] = i;
365
345
  xabs[i] = fabsf(x[i]);
366
346
  }
367
- std::sort(o, o + dim, [xabs](int a, int b) {
368
- return xabs[a] > xabs[b];
369
- });
347
+ std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
370
348
  for (int i = 0; i < dim; i++) {
371
349
  xperm[i] = xabs[o[i]];
372
350
  }
@@ -374,16 +352,16 @@ float ZnSphereSearch::search(const float *x, float *c,
374
352
  int ibest = -1;
375
353
  float dpbest = -100;
376
354
  for (int i = 0; i < natom; i++) {
377
- float dp = fvec_inner_product (voc.data() + i * dim, xperm, dim);
355
+ float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
378
356
  if (dp > dpbest) {
379
357
  dpbest = dp;
380
358
  ibest = i;
381
359
  }
382
360
  }
383
361
  // revert sort
384
- const float *cin = voc.data() + ibest * dim;
362
+ const float* cin = voc.data() + ibest * dim;
385
363
  for (int i = 0; i < dim; i++) {
386
- c[o[i]] = copysignf (cin[i], x[o[i]]);
364
+ c[o[i]] = copysignf(cin[i], x[o[i]]);
387
365
  }
388
366
  if (ibest_out) {
389
367
  *ibest_out = ibest;
@@ -391,33 +369,32 @@ float ZnSphereSearch::search(const float *x, float *c,
391
369
  return dpbest;
392
370
  }
393
371
 
394
- void ZnSphereSearch::search_multi(int n, const float *x,
395
- float *c_out,
396
- float *dp_out) {
372
+ void ZnSphereSearch::search_multi(
373
+ int n,
374
+ const float* x,
375
+ float* c_out,
376
+ float* dp_out) {
397
377
  #pragma omp parallel if (n > 1000)
398
378
  {
399
379
  #pragma omp for
400
- for(int i = 0; i < n; i++) {
380
+ for (int i = 0; i < n; i++) {
401
381
  dp_out[i] = search(x + i * dimS, c_out + i * dimS);
402
382
  }
403
383
  }
404
384
  }
405
385
 
406
-
407
386
  /**********************************************************
408
387
  * ZnSphereCodec
409
388
  **********************************************************/
410
389
 
411
- ZnSphereCodec::ZnSphereCodec(int dim, int r2):
412
- ZnSphereSearch(dim, r2),
413
- EnumeratedVectors(dim)
414
- {
390
+ ZnSphereCodec::ZnSphereCodec(int dim, int r2)
391
+ : ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
415
392
  nv = 0;
416
393
  for (int i = 0; i < natom; i++) {
417
394
  Repeats repeats(dim, &voc[i * dim]);
418
395
  CodeSegment cs(repeats);
419
396
  cs.c0 = nv;
420
- Repeat &br = repeats.repeats.back();
397
+ Repeat& br = repeats.repeats.back();
421
398
  cs.signbits = br.val == 0 ? dim - br.n : dim;
422
399
  code_segments.push_back(cs);
423
400
  nv += repeats.count() << cs.signbits;
@@ -431,7 +408,7 @@ ZnSphereCodec::ZnSphereCodec(int dim, int r2):
431
408
  }
432
409
  }
433
410
 
434
- uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
411
+ uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
435
412
  std::vector<float> tmp(dim * 2);
436
413
  std::vector<int> tmp_int(dim);
437
414
  int ano; // atom number
@@ -446,30 +423,30 @@ uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
446
423
  if (c[i] < 0) {
447
424
  signs |= uint64_t{1} << nnz;
448
425
  }
449
- nnz ++;
426
+ nnz++;
450
427
  }
451
428
  }
452
- const CodeSegment &cs = code_segments[ano];
429
+ const CodeSegment& cs = code_segments[ano];
453
430
  assert(nnz == cs.signbits);
454
431
  uint64_t code = cs.c0 + signs;
455
432
  code += cs.encode(cabs.data()) << cs.signbits;
456
433
  return code;
457
434
  }
458
435
 
459
- uint64_t ZnSphereCodec::encode(const float *x) const
460
- {
436
+ uint64_t ZnSphereCodec::encode(const float* x) const {
461
437
  return search_and_encode(x);
462
438
  }
463
439
 
464
-
465
- void ZnSphereCodec::decode(uint64_t code, float *c) const {
440
+ void ZnSphereCodec::decode(uint64_t code, float* c) const {
466
441
  int i0 = 0, i1 = natom;
467
442
  while (i0 + 1 < i1) {
468
443
  int imed = (i0 + i1) / 2;
469
- if (code_segments[imed].c0 <= code) i0 = imed;
470
- else i1 = imed;
444
+ if (code_segments[imed].c0 <= code)
445
+ i0 = imed;
446
+ else
447
+ i1 = imed;
471
448
  }
472
- const CodeSegment &cs = code_segments[i0];
449
+ const CodeSegment& cs = code_segments[i0];
473
450
  code -= cs.c0;
474
451
  uint64_t signs = code;
475
452
  code >>= cs.signbits;
@@ -481,42 +458,34 @@ void ZnSphereCodec::decode(uint64_t code, float *c) const {
481
458
  if (signs & (1UL << nnz)) {
482
459
  c[i] = -c[i];
483
460
  }
484
- nnz ++;
461
+ nnz++;
485
462
  }
486
463
  }
487
464
  }
488
465
 
489
-
490
466
  /**************************************************************
491
467
  * ZnSphereCodecRec
492
468
  **************************************************************/
493
469
 
494
- uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const
495
- {
470
+ uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
496
471
  return all_nv[ld * (r2 + 1) + r2a];
497
472
  }
498
473
 
499
-
500
- uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const
501
- {
474
+ uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
502
475
  return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
503
476
  }
504
477
 
505
- void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum)
506
- {
478
+ void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
507
479
  all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
508
480
  }
509
481
 
510
-
511
- ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
512
- EnumeratedVectors(dim), r2(r2)
513
- {
482
+ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
483
+ : EnumeratedVectors(dim), r2(r2) {
514
484
  log2_dim = 0;
515
485
  while (dim > (1 << log2_dim)) {
516
486
  log2_dim++;
517
487
  }
518
- assert(dim == (1 << log2_dim) ||
519
- !"dimension must be a power of 2");
488
+ assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
520
489
 
521
490
  all_nv.resize((log2_dim + 1) * (r2 + 1));
522
491
  all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
@@ -531,7 +500,6 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
531
500
  }
532
501
 
533
502
  for (int ld = 1; ld <= log2_dim; ld++) {
534
-
535
503
  for (int r2sub = 0; r2sub <= r2; r2sub++) {
536
504
  uint64_t nv = 0;
537
505
  for (int r2a = 0; r2a <= r2sub; r2a++) {
@@ -559,33 +527,29 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
559
527
  for (int r2sub = 0; r2sub <= r2; r2sub++) {
560
528
  int ld = cache_level;
561
529
  uint64_t nvi = get_nv(ld, r2sub);
562
- std::vector<float> &cache = decode_cache[r2sub];
530
+ std::vector<float>& cache = decode_cache[r2sub];
563
531
  int dimsub = (1 << cache_level);
564
- cache.resize (nvi * dimsub);
532
+ cache.resize(nvi * dimsub);
565
533
  std::vector<float> c(dim);
566
- uint64_t code0 = get_nv_cum(cache_level + 1, r2,
567
- r2 - r2sub);
534
+ uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
568
535
  for (int i = 0; i < nvi; i++) {
569
536
  decode(i + code0, c.data());
570
- memcpy(&cache[i * dimsub], c.data() + dim - dimsub,
537
+ memcpy(&cache[i * dimsub],
538
+ c.data() + dim - dimsub,
571
539
  dimsub * sizeof(*c.data()));
572
540
  }
573
541
  }
574
542
  decode_cache_ld = cache_level;
575
543
  }
576
544
 
577
- uint64_t ZnSphereCodecRec::encode(const float *c) const
578
- {
545
+ uint64_t ZnSphereCodecRec::encode(const float* c) const {
579
546
  return encode_centroid(c);
580
547
  }
581
548
 
582
-
583
-
584
- uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
585
- {
549
+ uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
586
550
  std::vector<uint64_t> codes(dim);
587
551
  std::vector<int> norm2s(dim);
588
- for(int i = 0; i < dim; i++) {
552
+ for (int i = 0; i < dim; i++) {
589
553
  if (c[i] == 0) {
590
554
  codes[i] = 0;
591
555
  norm2s[i] = 0;
@@ -596,7 +560,7 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
596
560
  }
597
561
  }
598
562
  int dim2 = dim / 2;
599
- for(int ld = 1; ld <= log2_dim; ld++) {
563
+ for (int ld = 1; ld <= log2_dim; ld++) {
600
564
  for (int i = 0; i < dim2; i++) {
601
565
  int r2a = norm2s[2 * i];
602
566
  int r2b = norm2s[2 * i + 1];
@@ -604,10 +568,8 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
604
568
  uint64_t code_a = codes[2 * i];
605
569
  uint64_t code_b = codes[2 * i + 1];
606
570
 
607
- codes[i] =
608
- get_nv_cum(ld, r2a + r2b, r2a) +
609
- code_a * get_nv(ld - 1, r2b) +
610
- code_b;
571
+ codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
572
+ code_a * get_nv(ld - 1, r2b) + code_b;
611
573
  norm2s[i] = r2a + r2b;
612
574
  }
613
575
  dim2 /= 2;
@@ -615,23 +577,20 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
615
577
  return codes[0];
616
578
  }
617
579
 
618
-
619
-
620
- void ZnSphereCodecRec::decode(uint64_t code, float *c) const
621
- {
580
+ void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
622
581
  std::vector<uint64_t> codes(dim);
623
582
  std::vector<int> norm2s(dim);
624
583
  codes[0] = code;
625
584
  norm2s[0] = r2;
626
585
 
627
586
  int dim2 = 1;
628
- for(int ld = log2_dim; ld > decode_cache_ld; ld--) {
587
+ for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
629
588
  for (int i = dim2 - 1; i >= 0; i--) {
630
589
  int r2sub = norm2s[i];
631
590
  int i0 = 0, i1 = r2sub + 1;
632
591
  uint64_t codei = codes[i];
633
- const uint64_t *cum =
634
- &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
592
+ const uint64_t* cum =
593
+ &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
635
594
  while (i1 > i0 + 1) {
636
595
  int imed = (i0 + i1) / 2;
637
596
  if (cum[imed] <= codei)
@@ -649,13 +608,12 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
649
608
 
650
609
  codes[2 * i] = code_a;
651
610
  codes[2 * i + 1] = code_b;
652
-
653
611
  }
654
612
  dim2 *= 2;
655
613
  }
656
614
 
657
615
  if (decode_cache_ld == 0) {
658
- for(int i = 0; i < dim; i++) {
616
+ for (int i = 0; i < dim; i++) {
659
617
  if (norm2s[i] == 0) {
660
618
  c[i] = 0;
661
619
  } else {
@@ -666,49 +624,42 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
666
624
  }
667
625
  } else {
668
626
  int subdim = 1 << decode_cache_ld;
669
- assert ((dim2 * subdim) == dim);
670
-
671
- for(int i = 0; i < dim2; i++) {
627
+ assert((dim2 * subdim) == dim);
672
628
 
673
- const std::vector<float> & cache =
674
- decode_cache[norm2s[i]];
629
+ for (int i = 0; i < dim2; i++) {
630
+ const std::vector<float>& cache = decode_cache[norm2s[i]];
675
631
  assert(codes[i] < cache.size());
676
632
  memcpy(c + i * subdim,
677
633
  &cache[codes[i] * subdim],
678
- sizeof(*c)* subdim);
634
+ sizeof(*c) * subdim);
679
635
  }
680
636
  }
681
637
  }
682
638
 
683
639
  // if not use_rec, instanciate an arbitrary harmless znc_rec
684
- ZnSphereCodecAlt::ZnSphereCodecAlt (int dim, int r2):
685
- ZnSphereCodec (dim, r2),
686
- use_rec ((dim & (dim - 1)) == 0),
687
- znc_rec (use_rec ? dim : 8,
688
- use_rec ? r2 : 14)
689
- {}
690
-
691
- uint64_t ZnSphereCodecAlt::encode(const float *x) const
692
- {
640
+ ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
641
+ : ZnSphereCodec(dim, r2),
642
+ use_rec((dim & (dim - 1)) == 0),
643
+ znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
644
+
645
+ uint64_t ZnSphereCodecAlt::encode(const float* x) const {
693
646
  if (!use_rec) {
694
647
  // it's ok if the vector is not normalized
695
648
  return ZnSphereCodec::encode(x);
696
649
  } else {
697
650
  // find nearest centroid
698
651
  std::vector<float> centroid(dim);
699
- search (x, centroid.data());
652
+ search(x, centroid.data());
700
653
  return znc_rec.encode(centroid.data());
701
654
  }
702
655
  }
703
656
 
704
- void ZnSphereCodecAlt::decode(uint64_t code, float *c) const
705
- {
657
+ void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
706
658
  if (!use_rec) {
707
- ZnSphereCodec::decode (code, c);
659
+ ZnSphereCodec::decode(code, c);
708
660
  } else {
709
- znc_rec.decode (code, c);
661
+ znc_rec.decode(code, c);
710
662
  }
711
663
  }
712
664
 
713
-
714
665
  } // namespace faiss