faiss 0.1.5 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -15,43 +15,54 @@
15
15
  * always called f and thus is not passed in as a macro parameter.
16
16
  **************************************************************/
17
17
 
18
-
19
- #define READANDCHECK(ptr, n) { \
20
- size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
21
- FAISS_THROW_IF_NOT_FMT(ret == (n), \
22
- "read error in %s: %zd != %zd (%s)", \
23
- f->name.c_str(), ret, size_t(n), strerror(errno)); \
18
+ #define READANDCHECK(ptr, n) \
19
+ { \
20
+ size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
21
+ FAISS_THROW_IF_NOT_FMT( \
22
+ ret == (n), \
23
+ "read error in %s: %zd != %zd (%s)", \
24
+ f->name.c_str(), \
25
+ ret, \
26
+ size_t(n), \
27
+ strerror(errno)); \
24
28
  }
25
29
 
26
- #define READ1(x) READANDCHECK(&(x), 1)
30
+ #define READ1(x) READANDCHECK(&(x), 1)
27
31
 
28
32
  // will fail if we write 256G of data at once...
29
- #define READVECTOR(vec) \
30
- { \
31
- size_t size; \
32
- READANDCHECK(&size, 1); \
33
- FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \
34
- (vec).resize(size); \
35
- READANDCHECK((vec).data(), size); \
36
- }
37
-
38
- #define READSTRING(s) { \
39
- size_t size = (s).size (); \
40
- WRITEANDCHECK (&size, 1); \
41
- WRITEANDCHECK ((s).c_str(), size); \
33
+ #define READVECTOR(vec) \
34
+ { \
35
+ size_t size; \
36
+ READANDCHECK(&size, 1); \
37
+ FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \
38
+ (vec).resize(size); \
39
+ READANDCHECK((vec).data(), size); \
40
+ }
41
+
42
+ #define READSTRING(s) \
43
+ { \
44
+ size_t size = (s).size(); \
45
+ WRITEANDCHECK(&size, 1); \
46
+ WRITEANDCHECK((s).c_str(), size); \
42
47
  }
43
48
 
44
- #define WRITEANDCHECK(ptr, n) { \
45
- size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
46
- FAISS_THROW_IF_NOT_FMT(ret == (n), \
47
- "write error in %s: %zd != %zd (%s)", \
48
- f->name.c_str(), ret, size_t(n), strerror(errno)); \
49
+ #define WRITEANDCHECK(ptr, n) \
50
+ { \
51
+ size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
52
+ FAISS_THROW_IF_NOT_FMT( \
53
+ ret == (n), \
54
+ "write error in %s: %zd != %zd (%s)", \
55
+ f->name.c_str(), \
56
+ ret, \
57
+ size_t(n), \
58
+ strerror(errno)); \
49
59
  }
50
60
 
51
61
  #define WRITE1(x) WRITEANDCHECK(&(x), 1)
52
62
 
53
- #define WRITEVECTOR(vec) { \
54
- size_t size = (vec).size (); \
55
- WRITEANDCHECK (&size, 1); \
56
- WRITEANDCHECK ((vec).data (), size); \
63
+ #define WRITEVECTOR(vec) \
64
+ { \
65
+ size_t size = (vec).size(); \
66
+ WRITEANDCHECK(&size, 1); \
67
+ WRITEANDCHECK((vec).data(), size); \
57
68
  }
@@ -9,19 +9,18 @@
9
9
 
10
10
  #include <faiss/impl/lattice_Zn.h>
11
11
 
12
- #include <cstdlib>
12
+ #include <cassert>
13
13
  #include <cmath>
14
+ #include <cstdlib>
14
15
  #include <cstring>
15
- #include <cassert>
16
16
 
17
+ #include <algorithm>
17
18
  #include <queue>
18
- #include <unordered_set>
19
19
  #include <unordered_map>
20
- #include <algorithm>
20
+ #include <unordered_set>
21
21
 
22
- #include <faiss/utils/distances.h>
23
22
  #include <faiss/impl/platform_macros.h>
24
-
23
+ #include <faiss/utils/distances.h>
25
24
 
26
25
  namespace faiss {
27
26
 
@@ -35,44 +34,41 @@ inline float sqr(float x) {
35
34
  return x * x;
36
35
  }
37
36
 
38
-
39
37
  typedef std::vector<float> point_list_t;
40
38
 
41
39
  struct Comb {
42
40
  std::vector<uint64_t> tab; // Pascal's triangle
43
41
  int nmax;
44
42
 
45
- explicit Comb(int nmax): nmax(nmax) {
43
+ explicit Comb(int nmax) : nmax(nmax) {
46
44
  tab.resize(nmax * nmax, 0);
47
45
  tab[0] = 1;
48
- for(int i = 1; i < nmax; i++) {
46
+ for (int i = 1; i < nmax; i++) {
49
47
  tab[i * nmax] = 1;
50
- for(int j = 1; j <= i; j++) {
48
+ for (int j = 1; j <= i; j++) {
51
49
  tab[i * nmax + j] =
52
- tab[(i - 1) * nmax + j] +
53
- tab[(i - 1) * nmax + (j - 1)];
50
+ tab[(i - 1) * nmax + j] + tab[(i - 1) * nmax + (j - 1)];
54
51
  }
55
-
56
52
  }
57
53
  }
58
54
 
59
55
  uint64_t operator()(int n, int p) const {
60
- assert (n < nmax && p < nmax);
61
- if (p > n) return 0;
56
+ assert(n < nmax && p < nmax);
57
+ if (p > n)
58
+ return 0;
62
59
  return tab[n * nmax + p];
63
60
  }
64
61
  };
65
62
 
66
63
  Comb comb(100);
67
64
 
68
-
69
-
70
65
  // compute combinations of n integer values <= v that sum up to total (squared)
71
- point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
66
+ point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
72
67
  if (total < 0) {
73
68
  return point_list_t();
74
69
  } else if (n == 1) {
75
- while (sqr(v + add) > total) v--;
70
+ while (sqr(v + add) > total)
71
+ v--;
76
72
  if (sqr(v + add) == total) {
77
73
  return point_list_t(1, v + add);
78
74
  } else {
@@ -82,9 +78,9 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
82
78
  point_list_t res;
83
79
  while (v >= 0) {
84
80
  point_list_t sub_points =
85
- sum_of_sq (total - sqr(v + add), v, n - 1, add);
81
+ sum_of_sq(total - sqr(v + add), v, n - 1, add);
86
82
  for (size_t i = 0; i < sub_points.size(); i += n - 1) {
87
- res.push_back (v + add);
83
+ res.push_back(v + add);
88
84
  for (int j = 0; j < n - 1; j++) {
89
85
  res.push_back(sub_points[i + j]);
90
86
  }
@@ -95,7 +91,7 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
95
91
  }
96
92
  }
97
93
 
98
- int decode_comb_1 (uint64_t *n, int k1, int r) {
94
+ int decode_comb_1(uint64_t* n, int k1, int r) {
99
95
  while (comb(r, k1) > *n) {
100
96
  r--;
101
97
  }
@@ -104,10 +100,10 @@ int decode_comb_1 (uint64_t *n, int k1, int r) {
104
100
  }
105
101
 
106
102
  // optimized version for < 64 bits
107
- uint64_t repeats_encode_64 (
108
- const std::vector<Repeat> & repeats,
109
- int dim, const float *c)
110
- {
103
+ uint64_t repeats_encode_64(
104
+ const std::vector<Repeat>& repeats,
105
+ int dim,
106
+ const float* c) {
111
107
  uint64_t coded = 0;
112
108
  int nfree = dim;
113
109
  uint64_t code = 0, shift = 1;
@@ -115,15 +111,16 @@ uint64_t repeats_encode_64 (
115
111
  int rank = 0, occ = 0;
116
112
  uint64_t code_comb = 0;
117
113
  uint64_t tosee = ~coded;
118
- for(;;) {
114
+ for (;;) {
119
115
  // directly jump to next available slot.
120
116
  int i = __builtin_ctzll(tosee);
121
- tosee &= ~(uint64_t{1} << i) ;
117
+ tosee &= ~(uint64_t{1} << i);
122
118
  if (c[i] == r->val) {
123
119
  code_comb += comb(rank, occ + 1);
124
120
  occ++;
125
121
  coded |= uint64_t{1} << i;
126
- if (occ == r->n) break;
122
+ if (occ == r->n)
123
+ break;
127
124
  }
128
125
  rank++;
129
126
  }
@@ -135,11 +132,11 @@ uint64_t repeats_encode_64 (
135
132
  return code;
136
133
  }
137
134
 
138
-
139
135
  void repeats_decode_64(
140
- const std::vector<Repeat> & repeats,
141
- int dim, uint64_t code, float *c)
142
- {
136
+ const std::vector<Repeat>& repeats,
137
+ int dim,
138
+ uint64_t code,
139
+ float* c) {
143
140
  uint64_t decoded = 0;
144
141
  int nfree = dim;
145
142
  for (auto r = repeats.begin(); r != repeats.end(); ++r) {
@@ -149,9 +146,9 @@ void repeats_decode_64(
149
146
 
150
147
  int occ = 0;
151
148
  int rank = nfree;
152
- int next_rank = decode_comb_1 (&code_comb, r->n, rank);
149
+ int next_rank = decode_comb_1(&code_comb, r->n, rank);
153
150
  uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
154
- for(;;) {
151
+ for (;;) {
155
152
  int i = 63 - __builtin_clzll(tosee);
156
153
  tosee &= ~(uint64_t{1} << i);
157
154
  rank--;
@@ -159,25 +156,21 @@ void repeats_decode_64(
159
156
  decoded |= uint64_t{1} << i;
160
157
  c[i] = r->val;
161
158
  occ++;
162
- if (occ == r->n) break;
163
- next_rank = decode_comb_1 (
164
- &code_comb, r->n - occ, next_rank);
159
+ if (occ == r->n)
160
+ break;
161
+ next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
165
162
  }
166
163
  }
167
164
  nfree -= r->n;
168
165
  }
169
-
170
166
  }
171
167
 
172
-
173
-
174
168
  } // anonymous namespace
175
169
 
176
- Repeats::Repeats (int dim, const float *c): dim(dim)
177
- {
178
- for(int i = 0; i < dim; i++) {
170
+ Repeats::Repeats(int dim, const float* c) : dim(dim) {
171
+ for (int i = 0; i < dim; i++) {
179
172
  int j = 0;
180
- for(;;) {
173
+ for (;;) {
181
174
  if (j == repeats.size()) {
182
175
  repeats.push_back(Repeat{c[i], 1});
183
176
  break;
@@ -191,9 +184,7 @@ Repeats::Repeats (int dim, const float *c): dim(dim)
191
184
  }
192
185
  }
193
186
 
194
-
195
- uint64_t Repeats::count () const
196
- {
187
+ uint64_t Repeats::count() const {
197
188
  uint64_t accu = 1;
198
189
  int remain = dim;
199
190
  for (int i = 0; i < repeats.size(); i++) {
@@ -203,13 +194,10 @@ uint64_t Repeats::count () const
203
194
  return accu;
204
195
  }
205
196
 
206
-
207
-
208
197
  // version with a bool vector that works for > 64 dim
209
- uint64_t Repeats::encode(const float *c) const
210
- {
198
+ uint64_t Repeats::encode(const float* c) const {
211
199
  if (dim < 64) {
212
- return repeats_encode_64 (repeats, dim, c);
200
+ return repeats_encode_64(repeats, dim, c);
213
201
  }
214
202
  std::vector<bool> coded(dim, false);
215
203
  int nfree = dim;
@@ -223,7 +211,8 @@ uint64_t Repeats::encode(const float *c) const
223
211
  code_comb += comb(rank, occ + 1);
224
212
  occ++;
225
213
  coded[i] = true;
226
- if (occ == r->n) break;
214
+ if (occ == r->n)
215
+ break;
227
216
  }
228
217
  rank++;
229
218
  }
@@ -236,12 +225,9 @@ uint64_t Repeats::encode(const float *c) const
236
225
  return code;
237
226
  }
238
227
 
239
-
240
-
241
- void Repeats::decode(uint64_t code, float *c) const
242
- {
228
+ void Repeats::decode(uint64_t code, float* c) const {
243
229
  if (dim < 64) {
244
- repeats_decode_64 (repeats, dim, code, c);
230
+ repeats_decode_64(repeats, dim, code, c);
245
231
  return;
246
232
  }
247
233
 
@@ -254,7 +240,7 @@ void Repeats::decode(uint64_t code, float *c) const
254
240
 
255
241
  int occ = 0;
256
242
  int rank = nfree;
257
- int next_rank = decode_comb_1 (&code_comb, r->n, rank);
243
+ int next_rank = decode_comb_1(&code_comb, r->n, rank);
258
244
  for (int i = dim - 1; i >= 0; i--) {
259
245
  if (!decoded[i]) {
260
246
  rank--;
@@ -262,65 +248,61 @@ void Repeats::decode(uint64_t code, float *c) const
262
248
  decoded[i] = true;
263
249
  c[i] = r->val;
264
250
  occ++;
265
- if (occ == r->n) break;
266
- next_rank = decode_comb_1 (
267
- &code_comb, r->n - occ, next_rank);
251
+ if (occ == r->n)
252
+ break;
253
+ next_rank =
254
+ decode_comb_1(&code_comb, r->n - occ, next_rank);
268
255
  }
269
256
  }
270
257
  }
271
258
  nfree -= r->n;
272
259
  }
273
-
274
260
  }
275
261
 
276
-
277
-
278
262
  /********************************************
279
263
  * EnumeratedVectors functions
280
264
  ********************************************/
281
265
 
282
-
283
- void EnumeratedVectors::encode_multi(size_t n, const float *c,
284
- uint64_t * codes) const
285
- {
266
+ void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
267
+ const {
286
268
  #pragma omp parallel if (n > 1000)
287
269
  {
288
270
  #pragma omp for
289
- for(int i = 0; i < n; i++) {
271
+ for (int i = 0; i < n; i++) {
290
272
  codes[i] = encode(c + i * dim);
291
273
  }
292
274
  }
293
275
  }
294
276
 
295
-
296
- void EnumeratedVectors::decode_multi(size_t n, const uint64_t * codes,
297
- float *c) const
298
- {
277
+ void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
278
+ const {
299
279
  #pragma omp parallel if (n > 1000)
300
280
  {
301
281
  #pragma omp for
302
- for(int i = 0; i < n; i++) {
282
+ for (int i = 0; i < n; i++) {
303
283
  decode(codes[i], c + i * dim);
304
284
  }
305
285
  }
306
286
  }
307
287
 
308
- void EnumeratedVectors::find_nn (
309
- size_t nc, const uint64_t * codes,
310
- size_t nq, const float *xq,
311
- int64_t *labels, float *distances)
312
- {
288
+ void EnumeratedVectors::find_nn(
289
+ size_t nc,
290
+ const uint64_t* codes,
291
+ size_t nq,
292
+ const float* xq,
293
+ int64_t* labels,
294
+ float* distances) {
313
295
  for (size_t i = 0; i < nq; i++) {
314
296
  distances[i] = -1e20;
315
297
  labels[i] = -1;
316
298
  }
317
299
 
318
300
  std::vector<float> c(dim);
319
- for(size_t i = 0; i < nc; i++) {
301
+ for (size_t i = 0; i < nc; i++) {
320
302
  uint64_t code = codes[nc];
321
303
  decode(code, c.data());
322
304
  for (size_t j = 0; j < nq; j++) {
323
- const float *x = xq + j * dim;
305
+ const float* x = xq + j * dim;
324
306
  float dis = fvec_inner_product(x, c.data(), dim);
325
307
  if (dis > distances[j]) {
326
308
  distances[j] = dis;
@@ -328,45 +310,41 @@ void EnumeratedVectors::find_nn (
328
310
  }
329
311
  }
330
312
  }
331
-
332
313
  }
333
314
 
334
-
335
315
  /**********************************************************
336
316
  * ZnSphereSearch
337
317
  **********************************************************/
338
318
 
339
-
340
- ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) {
319
+ ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
341
320
  voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
342
321
  natom = voc.size() / dim;
343
322
  }
344
323
 
345
- float ZnSphereSearch::search(const float *x, float *c) const {
324
+ float ZnSphereSearch::search(const float* x, float* c) const {
346
325
  std::vector<float> tmp(dimS * 2);
347
326
  std::vector<int> tmp_int(dimS);
348
327
  return search(x, c, tmp.data(), tmp_int.data());
349
328
  }
350
329
 
351
- float ZnSphereSearch::search(const float *x, float *c,
352
- float *tmp, // size 2 *dim
353
- int *tmp_int, // size dim
354
- int *ibest_out
355
- ) const {
330
+ float ZnSphereSearch::search(
331
+ const float* x,
332
+ float* c,
333
+ float* tmp, // size 2 *dim
334
+ int* tmp_int, // size dim
335
+ int* ibest_out) const {
356
336
  int dim = dimS;
357
- assert (natom > 0);
358
- int *o = tmp_int;
359
- float *xabs = tmp;
360
- float *xperm = tmp + dim;
337
+ assert(natom > 0);
338
+ int* o = tmp_int;
339
+ float* xabs = tmp;
340
+ float* xperm = tmp + dim;
361
341
 
362
342
  // argsort
363
343
  for (int i = 0; i < dim; i++) {
364
344
  o[i] = i;
365
345
  xabs[i] = fabsf(x[i]);
366
346
  }
367
- std::sort(o, o + dim, [xabs](int a, int b) {
368
- return xabs[a] > xabs[b];
369
- });
347
+ std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
370
348
  for (int i = 0; i < dim; i++) {
371
349
  xperm[i] = xabs[o[i]];
372
350
  }
@@ -374,16 +352,16 @@ float ZnSphereSearch::search(const float *x, float *c,
374
352
  int ibest = -1;
375
353
  float dpbest = -100;
376
354
  for (int i = 0; i < natom; i++) {
377
- float dp = fvec_inner_product (voc.data() + i * dim, xperm, dim);
355
+ float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
378
356
  if (dp > dpbest) {
379
357
  dpbest = dp;
380
358
  ibest = i;
381
359
  }
382
360
  }
383
361
  // revert sort
384
- const float *cin = voc.data() + ibest * dim;
362
+ const float* cin = voc.data() + ibest * dim;
385
363
  for (int i = 0; i < dim; i++) {
386
- c[o[i]] = copysignf (cin[i], x[o[i]]);
364
+ c[o[i]] = copysignf(cin[i], x[o[i]]);
387
365
  }
388
366
  if (ibest_out) {
389
367
  *ibest_out = ibest;
@@ -391,33 +369,32 @@ float ZnSphereSearch::search(const float *x, float *c,
391
369
  return dpbest;
392
370
  }
393
371
 
394
- void ZnSphereSearch::search_multi(int n, const float *x,
395
- float *c_out,
396
- float *dp_out) {
372
+ void ZnSphereSearch::search_multi(
373
+ int n,
374
+ const float* x,
375
+ float* c_out,
376
+ float* dp_out) {
397
377
  #pragma omp parallel if (n > 1000)
398
378
  {
399
379
  #pragma omp for
400
- for(int i = 0; i < n; i++) {
380
+ for (int i = 0; i < n; i++) {
401
381
  dp_out[i] = search(x + i * dimS, c_out + i * dimS);
402
382
  }
403
383
  }
404
384
  }
405
385
 
406
-
407
386
  /**********************************************************
408
387
  * ZnSphereCodec
409
388
  **********************************************************/
410
389
 
411
- ZnSphereCodec::ZnSphereCodec(int dim, int r2):
412
- ZnSphereSearch(dim, r2),
413
- EnumeratedVectors(dim)
414
- {
390
+ ZnSphereCodec::ZnSphereCodec(int dim, int r2)
391
+ : ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
415
392
  nv = 0;
416
393
  for (int i = 0; i < natom; i++) {
417
394
  Repeats repeats(dim, &voc[i * dim]);
418
395
  CodeSegment cs(repeats);
419
396
  cs.c0 = nv;
420
- Repeat &br = repeats.repeats.back();
397
+ Repeat& br = repeats.repeats.back();
421
398
  cs.signbits = br.val == 0 ? dim - br.n : dim;
422
399
  code_segments.push_back(cs);
423
400
  nv += repeats.count() << cs.signbits;
@@ -431,7 +408,7 @@ ZnSphereCodec::ZnSphereCodec(int dim, int r2):
431
408
  }
432
409
  }
433
410
 
434
- uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
411
+ uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
435
412
  std::vector<float> tmp(dim * 2);
436
413
  std::vector<int> tmp_int(dim);
437
414
  int ano; // atom number
@@ -446,30 +423,30 @@ uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
446
423
  if (c[i] < 0) {
447
424
  signs |= uint64_t{1} << nnz;
448
425
  }
449
- nnz ++;
426
+ nnz++;
450
427
  }
451
428
  }
452
- const CodeSegment &cs = code_segments[ano];
429
+ const CodeSegment& cs = code_segments[ano];
453
430
  assert(nnz == cs.signbits);
454
431
  uint64_t code = cs.c0 + signs;
455
432
  code += cs.encode(cabs.data()) << cs.signbits;
456
433
  return code;
457
434
  }
458
435
 
459
- uint64_t ZnSphereCodec::encode(const float *x) const
460
- {
436
+ uint64_t ZnSphereCodec::encode(const float* x) const {
461
437
  return search_and_encode(x);
462
438
  }
463
439
 
464
-
465
- void ZnSphereCodec::decode(uint64_t code, float *c) const {
440
+ void ZnSphereCodec::decode(uint64_t code, float* c) const {
466
441
  int i0 = 0, i1 = natom;
467
442
  while (i0 + 1 < i1) {
468
443
  int imed = (i0 + i1) / 2;
469
- if (code_segments[imed].c0 <= code) i0 = imed;
470
- else i1 = imed;
444
+ if (code_segments[imed].c0 <= code)
445
+ i0 = imed;
446
+ else
447
+ i1 = imed;
471
448
  }
472
- const CodeSegment &cs = code_segments[i0];
449
+ const CodeSegment& cs = code_segments[i0];
473
450
  code -= cs.c0;
474
451
  uint64_t signs = code;
475
452
  code >>= cs.signbits;
@@ -481,42 +458,34 @@ void ZnSphereCodec::decode(uint64_t code, float *c) const {
481
458
  if (signs & (1UL << nnz)) {
482
459
  c[i] = -c[i];
483
460
  }
484
- nnz ++;
461
+ nnz++;
485
462
  }
486
463
  }
487
464
  }
488
465
 
489
-
490
466
  /**************************************************************
491
467
  * ZnSphereCodecRec
492
468
  **************************************************************/
493
469
 
494
- uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const
495
- {
470
+ uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
496
471
  return all_nv[ld * (r2 + 1) + r2a];
497
472
  }
498
473
 
499
-
500
- uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const
501
- {
474
+ uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
502
475
  return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
503
476
  }
504
477
 
505
- void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum)
506
- {
478
+ void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
507
479
  all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
508
480
  }
509
481
 
510
-
511
- ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
512
- EnumeratedVectors(dim), r2(r2)
513
- {
482
+ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
483
+ : EnumeratedVectors(dim), r2(r2) {
514
484
  log2_dim = 0;
515
485
  while (dim > (1 << log2_dim)) {
516
486
  log2_dim++;
517
487
  }
518
- assert(dim == (1 << log2_dim) ||
519
- !"dimension must be a power of 2");
488
+ assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
520
489
 
521
490
  all_nv.resize((log2_dim + 1) * (r2 + 1));
522
491
  all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
@@ -531,7 +500,6 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
531
500
  }
532
501
 
533
502
  for (int ld = 1; ld <= log2_dim; ld++) {
534
-
535
503
  for (int r2sub = 0; r2sub <= r2; r2sub++) {
536
504
  uint64_t nv = 0;
537
505
  for (int r2a = 0; r2a <= r2sub; r2a++) {
@@ -559,33 +527,29 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
559
527
  for (int r2sub = 0; r2sub <= r2; r2sub++) {
560
528
  int ld = cache_level;
561
529
  uint64_t nvi = get_nv(ld, r2sub);
562
- std::vector<float> &cache = decode_cache[r2sub];
530
+ std::vector<float>& cache = decode_cache[r2sub];
563
531
  int dimsub = (1 << cache_level);
564
- cache.resize (nvi * dimsub);
532
+ cache.resize(nvi * dimsub);
565
533
  std::vector<float> c(dim);
566
- uint64_t code0 = get_nv_cum(cache_level + 1, r2,
567
- r2 - r2sub);
534
+ uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
568
535
  for (int i = 0; i < nvi; i++) {
569
536
  decode(i + code0, c.data());
570
- memcpy(&cache[i * dimsub], c.data() + dim - dimsub,
537
+ memcpy(&cache[i * dimsub],
538
+ c.data() + dim - dimsub,
571
539
  dimsub * sizeof(*c.data()));
572
540
  }
573
541
  }
574
542
  decode_cache_ld = cache_level;
575
543
  }
576
544
 
577
- uint64_t ZnSphereCodecRec::encode(const float *c) const
578
- {
545
+ uint64_t ZnSphereCodecRec::encode(const float* c) const {
579
546
  return encode_centroid(c);
580
547
  }
581
548
 
582
-
583
-
584
- uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
585
- {
549
+ uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
586
550
  std::vector<uint64_t> codes(dim);
587
551
  std::vector<int> norm2s(dim);
588
- for(int i = 0; i < dim; i++) {
552
+ for (int i = 0; i < dim; i++) {
589
553
  if (c[i] == 0) {
590
554
  codes[i] = 0;
591
555
  norm2s[i] = 0;
@@ -596,7 +560,7 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
596
560
  }
597
561
  }
598
562
  int dim2 = dim / 2;
599
- for(int ld = 1; ld <= log2_dim; ld++) {
563
+ for (int ld = 1; ld <= log2_dim; ld++) {
600
564
  for (int i = 0; i < dim2; i++) {
601
565
  int r2a = norm2s[2 * i];
602
566
  int r2b = norm2s[2 * i + 1];
@@ -604,10 +568,8 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
604
568
  uint64_t code_a = codes[2 * i];
605
569
  uint64_t code_b = codes[2 * i + 1];
606
570
 
607
- codes[i] =
608
- get_nv_cum(ld, r2a + r2b, r2a) +
609
- code_a * get_nv(ld - 1, r2b) +
610
- code_b;
571
+ codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
572
+ code_a * get_nv(ld - 1, r2b) + code_b;
611
573
  norm2s[i] = r2a + r2b;
612
574
  }
613
575
  dim2 /= 2;
@@ -615,23 +577,20 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
615
577
  return codes[0];
616
578
  }
617
579
 
618
-
619
-
620
- void ZnSphereCodecRec::decode(uint64_t code, float *c) const
621
- {
580
+ void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
622
581
  std::vector<uint64_t> codes(dim);
623
582
  std::vector<int> norm2s(dim);
624
583
  codes[0] = code;
625
584
  norm2s[0] = r2;
626
585
 
627
586
  int dim2 = 1;
628
- for(int ld = log2_dim; ld > decode_cache_ld; ld--) {
587
+ for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
629
588
  for (int i = dim2 - 1; i >= 0; i--) {
630
589
  int r2sub = norm2s[i];
631
590
  int i0 = 0, i1 = r2sub + 1;
632
591
  uint64_t codei = codes[i];
633
- const uint64_t *cum =
634
- &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
592
+ const uint64_t* cum =
593
+ &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
635
594
  while (i1 > i0 + 1) {
636
595
  int imed = (i0 + i1) / 2;
637
596
  if (cum[imed] <= codei)
@@ -649,13 +608,12 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
649
608
 
650
609
  codes[2 * i] = code_a;
651
610
  codes[2 * i + 1] = code_b;
652
-
653
611
  }
654
612
  dim2 *= 2;
655
613
  }
656
614
 
657
615
  if (decode_cache_ld == 0) {
658
- for(int i = 0; i < dim; i++) {
616
+ for (int i = 0; i < dim; i++) {
659
617
  if (norm2s[i] == 0) {
660
618
  c[i] = 0;
661
619
  } else {
@@ -666,49 +624,42 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
666
624
  }
667
625
  } else {
668
626
  int subdim = 1 << decode_cache_ld;
669
- assert ((dim2 * subdim) == dim);
670
-
671
- for(int i = 0; i < dim2; i++) {
627
+ assert((dim2 * subdim) == dim);
672
628
 
673
- const std::vector<float> & cache =
674
- decode_cache[norm2s[i]];
629
+ for (int i = 0; i < dim2; i++) {
630
+ const std::vector<float>& cache = decode_cache[norm2s[i]];
675
631
  assert(codes[i] < cache.size());
676
632
  memcpy(c + i * subdim,
677
633
  &cache[codes[i] * subdim],
678
- sizeof(*c)* subdim);
634
+ sizeof(*c) * subdim);
679
635
  }
680
636
  }
681
637
  }
682
638
 
683
639
  // if not use_rec, instanciate an arbitrary harmless znc_rec
684
- ZnSphereCodecAlt::ZnSphereCodecAlt (int dim, int r2):
685
- ZnSphereCodec (dim, r2),
686
- use_rec ((dim & (dim - 1)) == 0),
687
- znc_rec (use_rec ? dim : 8,
688
- use_rec ? r2 : 14)
689
- {}
690
-
691
- uint64_t ZnSphereCodecAlt::encode(const float *x) const
692
- {
640
+ ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
641
+ : ZnSphereCodec(dim, r2),
642
+ use_rec((dim & (dim - 1)) == 0),
643
+ znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
644
+
645
+ uint64_t ZnSphereCodecAlt::encode(const float* x) const {
693
646
  if (!use_rec) {
694
647
  // it's ok if the vector is not normalized
695
648
  return ZnSphereCodec::encode(x);
696
649
  } else {
697
650
  // find nearest centroid
698
651
  std::vector<float> centroid(dim);
699
- search (x, centroid.data());
652
+ search(x, centroid.data());
700
653
  return znc_rec.encode(centroid.data());
701
654
  }
702
655
  }
703
656
 
704
- void ZnSphereCodecAlt::decode(uint64_t code, float *c) const
705
- {
657
+ void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
706
658
  if (!use_rec) {
707
- ZnSphereCodec::decode (code, c);
659
+ ZnSphereCodec::decode(code, c);
708
660
  } else {
709
- znc_rec.decode (code, c);
661
+ znc_rec.decode(code, c);
710
662
  }
711
663
  }
712
664
 
713
-
714
665
  } // namespace faiss