faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -9,10 +9,10 @@
9
9
 
10
10
  #include <faiss/utils/utils.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <cassert>
14
- #include <cstring>
15
13
  #include <cmath>
14
+ #include <cstdio>
15
+ #include <cstring>
16
16
 
17
17
  #include <sys/types.h>
18
18
 
@@ -32,46 +32,94 @@
32
32
 
33
33
  #include <faiss/impl/AuxIndexStructures.h>
34
34
  #include <faiss/impl/FaissAssert.h>
35
+ #include <faiss/impl/platform_macros.h>
35
36
  #include <faiss/utils/random.h>
36
37
 
37
-
38
-
39
38
  #ifndef FINTEGER
40
39
  #define FINTEGER long
41
40
  #endif
42
41
 
43
-
44
42
  extern "C" {
45
43
 
46
44
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
47
45
 
48
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
49
- n, FINTEGER *k, const float *alpha, const float *a,
50
- FINTEGER *lda, const float *b, FINTEGER *
51
- ldb, float *beta, float *c, FINTEGER *ldc);
46
+ int sgemm_(
47
+ const char* transa,
48
+ const char* transb,
49
+ FINTEGER* m,
50
+ FINTEGER* n,
51
+ FINTEGER* k,
52
+ const float* alpha,
53
+ const float* a,
54
+ FINTEGER* lda,
55
+ const float* b,
56
+ FINTEGER* ldb,
57
+ float* beta,
58
+ float* c,
59
+ FINTEGER* ldc);
52
60
 
53
61
  /* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
54
62
 
55
- int sgeqrf_ (FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda,
56
- float *tau, float *work, FINTEGER *lwork, FINTEGER *info);
57
-
58
- int sorgqr_(FINTEGER *m, FINTEGER *n, FINTEGER *k, float *a,
59
- FINTEGER *lda, float *tau, float *work,
60
- FINTEGER *lwork, FINTEGER *info);
61
-
62
- int sgemv_(const char *trans, FINTEGER *m, FINTEGER *n, float *alpha,
63
- const float *a, FINTEGER *lda, const float *x, FINTEGER *incx,
64
- float *beta, float *y, FINTEGER *incy);
65
-
63
+ int sgeqrf_(
64
+ FINTEGER* m,
65
+ FINTEGER* n,
66
+ float* a,
67
+ FINTEGER* lda,
68
+ float* tau,
69
+ float* work,
70
+ FINTEGER* lwork,
71
+ FINTEGER* info);
72
+
73
+ int sorgqr_(
74
+ FINTEGER* m,
75
+ FINTEGER* n,
76
+ FINTEGER* k,
77
+ float* a,
78
+ FINTEGER* lda,
79
+ float* tau,
80
+ float* work,
81
+ FINTEGER* lwork,
82
+ FINTEGER* info);
83
+
84
+ int sgemv_(
85
+ const char* trans,
86
+ FINTEGER* m,
87
+ FINTEGER* n,
88
+ float* alpha,
89
+ const float* a,
90
+ FINTEGER* lda,
91
+ const float* x,
92
+ FINTEGER* incx,
93
+ float* beta,
94
+ float* y,
95
+ FINTEGER* incy);
66
96
  }
67
97
 
68
-
69
98
  /**************************************************
70
99
  * Get some stats about the system
71
100
  **************************************************/
72
101
 
73
102
  namespace faiss {
74
103
 
104
+ std::string get_compile_options() {
105
+ std::string options;
106
+
107
+ // this flag is set by GCC and Clang
108
+ #ifdef __OPTIMIZE__
109
+ options += "OPTIMIZE ";
110
+ #endif
111
+
112
+ #ifdef __AVX2__
113
+ options += "AVX2";
114
+ #elif defined(__aarch64__)
115
+ options += "NEON";
116
+ #else
117
+ options += "GENERIC";
118
+ #endif
119
+
120
+ return options;
121
+ }
122
+
75
123
  #ifdef _MSC_VER
76
124
  double getmillisecs() {
77
125
  LARGE_INTEGER ts;
@@ -81,73 +129,69 @@ double getmillisecs() {
81
129
 
82
130
  return (ts.QuadPart * 1e3) / freq.QuadPart;
83
131
  }
84
- #else // _MSC_VER
85
- double getmillisecs () {
132
+ #else // _MSC_VER
133
+ double getmillisecs() {
86
134
  struct timeval tv;
87
- gettimeofday (&tv, nullptr);
135
+ gettimeofday(&tv, nullptr);
88
136
  return tv.tv_sec * 1e3 + tv.tv_usec * 1e-3;
89
137
  }
90
138
  #endif // _MSC_VER
91
139
 
92
- uint64_t get_cycles () {
93
- #ifdef __x86_64__
140
+ uint64_t get_cycles() {
141
+ #ifdef __x86_64__
94
142
  uint32_t high, low;
95
- asm volatile("rdtsc \n\t"
96
- : "=a" (low),
97
- "=d" (high));
143
+ asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
98
144
  return ((uint64_t)high << 32) | (low);
99
145
  #else
100
146
  return 0;
101
147
  #endif
102
148
  }
103
149
 
104
-
105
150
  #ifdef __linux__
106
151
 
107
- size_t get_mem_usage_kb ()
108
- {
109
- int pid = getpid ();
152
+ size_t get_mem_usage_kb() {
153
+ int pid = getpid();
110
154
  char fname[256];
111
- snprintf (fname, 256, "/proc/%d/status", pid);
112
- FILE * f = fopen (fname, "r");
113
- FAISS_THROW_IF_NOT_MSG (f, "cannot open proc status file");
155
+ snprintf(fname, 256, "/proc/%d/status", pid);
156
+ FILE* f = fopen(fname, "r");
157
+ FAISS_THROW_IF_NOT_MSG(f, "cannot open proc status file");
114
158
  size_t sz = 0;
115
159
  for (;;) {
116
- char buf [256];
117
- if (!fgets (buf, 256, f)) break;
118
- if (sscanf (buf, "VmRSS: %ld kB", &sz) == 1) break;
160
+ char buf[256];
161
+ if (!fgets(buf, 256, f))
162
+ break;
163
+ if (sscanf(buf, "VmRSS: %ld kB", &sz) == 1)
164
+ break;
119
165
  }
120
- fclose (f);
166
+ fclose(f);
121
167
  return sz;
122
168
  }
123
169
 
124
170
  #else
125
171
 
126
- size_t get_mem_usage_kb ()
127
- {
128
- fprintf(stderr, "WARN: get_mem_usage_kb not implemented on current architecture\n");
172
+ size_t get_mem_usage_kb() {
173
+ fprintf(stderr,
174
+ "WARN: get_mem_usage_kb not implemented on current architecture\n");
129
175
  return 0;
130
176
  }
131
177
 
132
178
  #endif
133
179
 
134
-
135
-
136
-
137
-
138
- void reflection (const float * __restrict u,
139
- float * __restrict x,
140
- size_t n, size_t d, size_t nu)
141
- {
180
+ void reflection(
181
+ const float* __restrict u,
182
+ float* __restrict x,
183
+ size_t n,
184
+ size_t d,
185
+ size_t nu) {
142
186
  size_t i, j, l;
143
187
  for (i = 0; i < n; i++) {
144
- const float * up = u;
188
+ const float* up = u;
145
189
  for (l = 0; l < nu; l++) {
146
190
  float ip1 = 0, ip2 = 0;
147
191
 
148
- for (j = 0; j < d; j+=2) {
192
+ for (j = 0; j < d; j += 2) {
149
193
  ip1 += up[j] * x[j];
150
- ip2 += up[j+1] * x[j+1];
194
+ ip2 += up[j + 1] * x[j + 1];
151
195
  }
152
196
  float ip = 2 * (ip1 + ip2);
153
197
 
@@ -159,13 +203,11 @@ void reflection (const float * __restrict u,
159
203
  }
160
204
  }
161
205
 
162
-
163
206
  /* Reference implementation (slower) */
164
- void reflection_ref (const float * u, float * x, size_t n, size_t d, size_t nu)
165
- {
207
+ void reflection_ref(const float* u, float* x, size_t n, size_t d, size_t nu) {
166
208
  size_t i, j, l;
167
209
  for (i = 0; i < n; i++) {
168
- const float * up = u;
210
+ const float* up = u;
169
211
  for (l = 0; l < nu; l++) {
170
212
  double ip = 0;
171
213
 
@@ -182,53 +224,38 @@ void reflection_ref (const float * u, float * x, size_t n, size_t d, size_t nu)
182
224
  }
183
225
  }
184
226
 
185
-
186
-
187
-
188
-
189
-
190
227
  /***************************************************************************
191
228
  * Some matrix manipulation functions
192
229
  ***************************************************************************/
193
230
 
194
- void matrix_qr (int m, int n, float *a)
195
- {
196
- FAISS_THROW_IF_NOT (m >= n);
231
+ void matrix_qr(int m, int n, float* a) {
232
+ FAISS_THROW_IF_NOT(m >= n);
197
233
  FINTEGER mi = m, ni = n, ki = mi < ni ? mi : ni;
198
- std::vector<float> tau (ki);
234
+ std::vector<float> tau(ki);
199
235
  FINTEGER lwork = -1, info;
200
236
  float work_size;
201
237
 
202
- sgeqrf_ (&mi, &ni, a, &mi, tau.data(),
203
- &work_size, &lwork, &info);
238
+ sgeqrf_(&mi, &ni, a, &mi, tau.data(), &work_size, &lwork, &info);
204
239
  lwork = size_t(work_size);
205
- std::vector<float> work (lwork);
206
-
207
- sgeqrf_ (&mi, &ni, a, &mi,
208
- tau.data(), work.data(), &lwork, &info);
240
+ std::vector<float> work(lwork);
209
241
 
210
- sorgqr_ (&mi, &ni, &ki, a, &mi, tau.data(),
211
- work.data(), &lwork, &info);
242
+ sgeqrf_(&mi, &ni, a, &mi, tau.data(), work.data(), &lwork, &info);
212
243
 
244
+ sorgqr_(&mi, &ni, &ki, a, &mi, tau.data(), work.data(), &lwork, &info);
213
245
  }
214
246
 
215
-
216
-
217
-
218
247
  /***************************************************************************
219
248
  * Result list routines
220
249
  ***************************************************************************/
221
250
 
222
-
223
- void ranklist_handle_ties (int k, int64_t *idx, const float *dis)
224
- {
251
+ void ranklist_handle_ties(int k, int64_t* idx, const float* dis) {
225
252
  float prev_dis = -1e38;
226
253
  int prev_i = -1;
227
254
  for (int i = 0; i < k; i++) {
228
255
  if (dis[i] != prev_dis) {
229
256
  if (i > prev_i + 1) {
230
257
  // sort between prev_i and i - 1
231
- std::sort (idx + prev_i, idx + i);
258
+ std::sort(idx + prev_i, idx + i);
232
259
  }
233
260
  prev_i = i;
234
261
  prev_dis = dis[i];
@@ -236,31 +263,33 @@ void ranklist_handle_ties (int k, int64_t *idx, const float *dis)
236
263
  }
237
264
  }
238
265
 
239
- size_t merge_result_table_with (size_t n, size_t k,
240
- int64_t *I0, float *D0,
241
- const int64_t *I1, const float *D1,
242
- bool keep_min,
243
- int64_t translation)
244
- {
266
+ size_t merge_result_table_with(
267
+ size_t n,
268
+ size_t k,
269
+ int64_t* I0,
270
+ float* D0,
271
+ const int64_t* I1,
272
+ const float* D1,
273
+ bool keep_min,
274
+ int64_t translation) {
245
275
  size_t n1 = 0;
246
276
 
247
- #pragma omp parallel reduction(+:n1)
277
+ #pragma omp parallel reduction(+ : n1)
248
278
  {
249
- std::vector<int64_t> tmpI (k);
250
- std::vector<float> tmpD (k);
279
+ std::vector<int64_t> tmpI(k);
280
+ std::vector<float> tmpD(k);
251
281
 
252
282
  #pragma omp for
253
283
  for (int64_t i = 0; i < n; i++) {
254
- int64_t *lI0 = I0 + i * k;
255
- float *lD0 = D0 + i * k;
256
- const int64_t *lI1 = I1 + i * k;
257
- const float *lD1 = D1 + i * k;
284
+ int64_t* lI0 = I0 + i * k;
285
+ float* lD0 = D0 + i * k;
286
+ const int64_t* lI1 = I1 + i * k;
287
+ const float* lD1 = D1 + i * k;
258
288
  size_t r0 = 0;
259
289
  size_t r1 = 0;
260
290
 
261
291
  if (keep_min) {
262
292
  for (size_t j = 0; j < k; j++) {
263
-
264
293
  if (lI0[r0] >= 0 && lD0[r0] < lD1[r1]) {
265
294
  tmpD[j] = lD0[r0];
266
295
  tmpI[j] = lI0[r0];
@@ -291,29 +320,30 @@ size_t merge_result_table_with (size_t n, size_t k,
291
320
  }
292
321
  }
293
322
  n1 += r1;
294
- memcpy (lD0, tmpD.data(), sizeof (lD0[0]) * k);
295
- memcpy (lI0, tmpI.data(), sizeof (lI0[0]) * k);
323
+ memcpy(lD0, tmpD.data(), sizeof(lD0[0]) * k);
324
+ memcpy(lI0, tmpI.data(), sizeof(lI0[0]) * k);
296
325
  }
297
326
  }
298
327
 
299
328
  return n1;
300
329
  }
301
330
 
302
-
303
-
304
- size_t ranklist_intersection_size (size_t k1, const int64_t *v1,
305
- size_t k2, const int64_t *v2_in)
306
- {
307
- if (k2 > k1) return ranklist_intersection_size (k2, v2_in, k1, v1);
308
- int64_t *v2 = new int64_t [k2];
309
- memcpy (v2, v2_in, sizeof (int64_t) * k2);
310
- std::sort (v2, v2 + k2);
331
+ size_t ranklist_intersection_size(
332
+ size_t k1,
333
+ const int64_t* v1,
334
+ size_t k2,
335
+ const int64_t* v2_in) {
336
+ if (k2 > k1)
337
+ return ranklist_intersection_size(k2, v2_in, k1, v1);
338
+ int64_t* v2 = new int64_t[k2];
339
+ memcpy(v2, v2_in, sizeof(int64_t) * k2);
340
+ std::sort(v2, v2 + k2);
311
341
  { // de-dup v2
312
342
  int64_t prev = -1;
313
343
  size_t wp = 0;
314
344
  for (size_t i = 0; i < k2; i++) {
315
- if (v2 [i] != prev) {
316
- v2[wp++] = prev = v2 [i];
345
+ if (v2[i] != prev) {
346
+ v2[wp++] = prev = v2[i];
317
347
  }
318
348
  }
319
349
  k2 = wp;
@@ -321,195 +351,196 @@ size_t ranklist_intersection_size (size_t k1, const int64_t *v1,
321
351
  const int64_t seen_flag = int64_t{1} << 60;
322
352
  size_t count = 0;
323
353
  for (size_t i = 0; i < k1; i++) {
324
- int64_t q = v1 [i];
354
+ int64_t q = v1[i];
325
355
  size_t i0 = 0, i1 = k2;
326
356
  while (i0 + 1 < i1) {
327
357
  size_t imed = (i1 + i0) / 2;
328
- int64_t piv = v2 [imed] & ~seen_flag;
329
- if (piv <= q) i0 = imed;
330
- else i1 = imed;
358
+ int64_t piv = v2[imed] & ~seen_flag;
359
+ if (piv <= q)
360
+ i0 = imed;
361
+ else
362
+ i1 = imed;
331
363
  }
332
- if (v2 [i0] == q) {
364
+ if (v2[i0] == q) {
333
365
  count++;
334
- v2 [i0] |= seen_flag;
366
+ v2[i0] |= seen_flag;
335
367
  }
336
368
  }
337
- delete [] v2;
369
+ delete[] v2;
338
370
 
339
371
  return count;
340
372
  }
341
373
 
342
- double imbalance_factor (int k, const int *hist) {
374
+ double imbalance_factor(int k, const int* hist) {
343
375
  double tot = 0, uf = 0;
344
376
 
345
- for (int i = 0 ; i < k ; i++) {
377
+ for (int i = 0; i < k; i++) {
346
378
  tot += hist[i];
347
- uf += hist[i] * (double) hist[i];
379
+ uf += hist[i] * (double)hist[i];
348
380
  }
349
381
  uf = uf * k / (tot * tot);
350
382
 
351
383
  return uf;
352
384
  }
353
385
 
354
-
355
- double imbalance_factor (int n, int k, const int64_t *assign) {
386
+ double imbalance_factor(int n, int k, const int64_t* assign) {
356
387
  std::vector<int> hist(k, 0);
357
388
  for (int i = 0; i < n; i++) {
358
389
  hist[assign[i]]++;
359
390
  }
360
391
 
361
- return imbalance_factor (k, hist.data());
392
+ return imbalance_factor(k, hist.data());
362
393
  }
363
394
 
364
-
365
-
366
- int ivec_hist (size_t n, const int * v, int vmax, int *hist) {
367
- memset (hist, 0, sizeof(hist[0]) * vmax);
395
+ int ivec_hist(size_t n, const int* v, int vmax, int* hist) {
396
+ memset(hist, 0, sizeof(hist[0]) * vmax);
368
397
  int nout = 0;
369
398
  while (n--) {
370
- if (v[n] < 0 || v[n] >= vmax) nout++;
371
- else hist[v[n]]++;
399
+ if (v[n] < 0 || v[n] >= vmax)
400
+ nout++;
401
+ else
402
+ hist[v[n]]++;
372
403
  }
373
404
  return nout;
374
405
  }
375
406
 
376
-
377
- void bincode_hist(size_t n, size_t nbits, const uint8_t *codes, int *hist)
378
- {
379
- FAISS_THROW_IF_NOT (nbits % 8 == 0);
407
+ void bincode_hist(size_t n, size_t nbits, const uint8_t* codes, int* hist) {
408
+ FAISS_THROW_IF_NOT(nbits % 8 == 0);
380
409
  size_t d = nbits / 8;
381
410
  std::vector<int> accu(d * 256);
382
- const uint8_t *c = codes;
411
+ const uint8_t* c = codes;
383
412
  for (size_t i = 0; i < n; i++)
384
- for(int j = 0; j < d; j++)
413
+ for (int j = 0; j < d; j++)
385
414
  accu[j * 256 + *c++]++;
386
- memset (hist, 0, sizeof(*hist) * nbits);
415
+ memset(hist, 0, sizeof(*hist) * nbits);
387
416
  for (int i = 0; i < d; i++) {
388
- const int *ai = accu.data() + i * 256;
389
- int * hi = hist + i * 8;
417
+ const int* ai = accu.data() + i * 256;
418
+ int* hi = hist + i * 8;
390
419
  for (int j = 0; j < 256; j++)
391
420
  for (int k = 0; k < 8; k++)
392
421
  if ((j >> k) & 1)
393
422
  hi[k] += ai[j];
394
423
  }
395
-
396
424
  }
397
425
 
398
-
399
-
400
- size_t ivec_checksum (size_t n, const int *a)
401
- {
426
+ size_t ivec_checksum(size_t n, const int* a) {
402
427
  size_t cs = 112909;
403
- while (n--) cs = cs * 65713 + a[n] * 1686049;
428
+ while (n--)
429
+ cs = cs * 65713 + a[n] * 1686049;
404
430
  return cs;
405
431
  }
406
432
 
407
-
408
433
  namespace {
409
- struct ArgsortComparator {
410
- const float *vals;
411
- bool operator() (const size_t a, const size_t b) const {
412
- return vals[a] < vals[b];
413
- }
414
- };
434
+ struct ArgsortComparator {
435
+ const float* vals;
436
+ bool operator()(const size_t a, const size_t b) const {
437
+ return vals[a] < vals[b];
438
+ }
439
+ };
415
440
 
416
- struct SegmentS {
417
- size_t i0; // begin pointer in the permutation array
418
- size_t i1; // end
419
- size_t len() const {
420
- return i1 - i0;
421
- }
422
- };
423
-
424
- // see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge
425
- // extended to > 1 merge thread
426
-
427
- // merges 2 ranges that should be consecutive on the source into
428
- // the union of the two on the destination
429
- template<typename T>
430
- void parallel_merge (const T *src, T *dst,
431
- SegmentS &s1, SegmentS & s2, int nt,
432
- const ArgsortComparator & comp) {
433
- if (s2.len() > s1.len()) { // make sure that s1 larger than s2
434
- std::swap(s1, s2);
435
- }
441
+ struct SegmentS {
442
+ size_t i0; // begin pointer in the permutation array
443
+ size_t i1; // end
444
+ size_t len() const {
445
+ return i1 - i0;
446
+ }
447
+ };
448
+
449
+ // see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge
450
+ // extended to > 1 merge thread
451
+
452
+ // merges 2 ranges that should be consecutive on the source into
453
+ // the union of the two on the destination
454
+ template <typename T>
455
+ void parallel_merge(
456
+ const T* src,
457
+ T* dst,
458
+ SegmentS& s1,
459
+ SegmentS& s2,
460
+ int nt,
461
+ const ArgsortComparator& comp) {
462
+ if (s2.len() > s1.len()) { // make sure that s1 larger than s2
463
+ std::swap(s1, s2);
464
+ }
436
465
 
437
- // compute sub-ranges for each thread
438
- std::vector<SegmentS> s1s(nt), s2s(nt), sws(nt);
439
- s2s[0].i0 = s2.i0;
440
- s2s[nt - 1].i1 = s2.i1;
466
+ // compute sub-ranges for each thread
467
+ std::vector<SegmentS> s1s(nt), s2s(nt), sws(nt);
468
+ s2s[0].i0 = s2.i0;
469
+ s2s[nt - 1].i1 = s2.i1;
441
470
 
442
- // not sure parallel actually helps here
471
+ // not sure parallel actually helps here
443
472
  #pragma omp parallel for num_threads(nt)
444
- for (int t = 0; t < nt; t++) {
445
- s1s[t].i0 = s1.i0 + s1.len() * t / nt;
446
- s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;
447
-
448
- if (t + 1 < nt) {
449
- T pivot = src[s1s[t].i1];
450
- size_t i0 = s2.i0, i1 = s2.i1;
451
- while (i0 + 1 < i1) {
452
- size_t imed = (i1 + i0) / 2;
453
- if (comp (pivot, src[imed])) {i1 = imed; }
454
- else {i0 = imed; }
473
+ for (int t = 0; t < nt; t++) {
474
+ s1s[t].i0 = s1.i0 + s1.len() * t / nt;
475
+ s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;
476
+
477
+ if (t + 1 < nt) {
478
+ T pivot = src[s1s[t].i1];
479
+ size_t i0 = s2.i0, i1 = s2.i1;
480
+ while (i0 + 1 < i1) {
481
+ size_t imed = (i1 + i0) / 2;
482
+ if (comp(pivot, src[imed])) {
483
+ i1 = imed;
484
+ } else {
485
+ i0 = imed;
455
486
  }
456
- s2s[t].i1 = s2s[t + 1].i0 = i1;
457
487
  }
488
+ s2s[t].i1 = s2s[t + 1].i0 = i1;
458
489
  }
459
- s1.i0 = std::min(s1.i0, s2.i0);
460
- s1.i1 = std::max(s1.i1, s2.i1);
461
- s2 = s1;
462
- sws[0].i0 = s1.i0;
463
- for (int t = 0; t < nt; t++) {
464
- sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len();
465
- if (t + 1 < nt) {
466
- sws[t + 1].i0 = sws[t].i1;
467
- }
490
+ }
491
+ s1.i0 = std::min(s1.i0, s2.i0);
492
+ s1.i1 = std::max(s1.i1, s2.i1);
493
+ s2 = s1;
494
+ sws[0].i0 = s1.i0;
495
+ for (int t = 0; t < nt; t++) {
496
+ sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len();
497
+ if (t + 1 < nt) {
498
+ sws[t + 1].i0 = sws[t].i1;
468
499
  }
469
- assert(sws[nt - 1].i1 == s1.i1);
500
+ }
501
+ assert(sws[nt - 1].i1 == s1.i1);
470
502
 
471
- // do the actual merging
503
+ // do the actual merging
472
504
  #pragma omp parallel for num_threads(nt)
473
- for (int t = 0; t < nt; t++) {
474
- SegmentS sw = sws[t];
475
- SegmentS s1t = s1s[t];
476
- SegmentS s2t = s2s[t];
477
- if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) {
478
- for (;;) {
479
- // assert (sw.len() == s1t.len() + s2t.len());
480
- if (comp(src[s1t.i0], src[s2t.i0])) {
481
- dst[sw.i0++] = src[s1t.i0++];
482
- if (s1t.i0 == s1t.i1) break;
483
- } else {
484
- dst[sw.i0++] = src[s2t.i0++];
485
- if (s2t.i0 == s2t.i1) break;
486
- }
505
+ for (int t = 0; t < nt; t++) {
506
+ SegmentS sw = sws[t];
507
+ SegmentS s1t = s1s[t];
508
+ SegmentS s2t = s2s[t];
509
+ if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) {
510
+ for (;;) {
511
+ // assert (sw.len() == s1t.len() + s2t.len());
512
+ if (comp(src[s1t.i0], src[s2t.i0])) {
513
+ dst[sw.i0++] = src[s1t.i0++];
514
+ if (s1t.i0 == s1t.i1)
515
+ break;
516
+ } else {
517
+ dst[sw.i0++] = src[s2t.i0++];
518
+ if (s2t.i0 == s2t.i1)
519
+ break;
487
520
  }
488
521
  }
489
- if (s1t.len() > 0) {
490
- assert(s1t.len() == sw.len());
491
- memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0]));
492
- } else if (s2t.len() > 0) {
493
- assert(s2t.len() == sw.len());
494
- memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0]));
495
- }
522
+ }
523
+ if (s1t.len() > 0) {
524
+ assert(s1t.len() == sw.len());
525
+ memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0]));
526
+ } else if (s2t.len() > 0) {
527
+ assert(s2t.len() == sw.len());
528
+ memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0]));
496
529
  }
497
530
  }
531
+ }
498
532
 
499
- };
533
+ }; // namespace
500
534
 
501
- void fvec_argsort (size_t n, const float *vals,
502
- size_t *perm)
503
- {
504
- for (size_t i = 0; i < n; i++) perm[i] = i;
535
+ void fvec_argsort(size_t n, const float* vals, size_t* perm) {
536
+ for (size_t i = 0; i < n; i++)
537
+ perm[i] = i;
505
538
  ArgsortComparator comp = {vals};
506
- std::sort (perm, perm + n, comp);
539
+ std::sort(perm, perm + n, comp);
507
540
  }
508
541
 
509
- void fvec_argsort_parallel (size_t n, const float *vals,
510
- size_t *perm)
511
- {
512
- size_t * perm2 = new size_t[n];
542
+ void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
543
+ size_t* perm2 = new size_t[n];
513
544
  // 2 result tables, during merging, flip between them
514
545
  size_t *permB = perm2, *permA = perm;
515
546
 
@@ -519,12 +550,13 @@ void fvec_argsort_parallel (size_t n, const float *vals,
519
550
  int nseg = nt;
520
551
  while (nseg > 1) {
521
552
  nseg = (nseg + 1) / 2;
522
- std::swap (permA, permB);
553
+ std::swap(permA, permB);
523
554
  }
524
555
  }
525
556
 
526
557
  #pragma omp parallel
527
- for (size_t i = 0; i < n; i++) permA[i] = i;
558
+ for (size_t i = 0; i < n; i++)
559
+ permA[i] = i;
528
560
 
529
561
  ArgsortComparator comp = {vals};
530
562
 
@@ -536,7 +568,7 @@ void fvec_argsort_parallel (size_t n, const float *vals,
536
568
  size_t i0 = t * n / nt;
537
569
  size_t i1 = (t + 1) * n / nt;
538
570
  SegmentS seg = {i0, i1};
539
- std::sort (permA + seg.i0, permA + seg.i1, comp);
571
+ std::sort(permA + seg.i0, permA + seg.i1, comp);
540
572
  segs[t] = seg;
541
573
  }
542
574
  int prev_nested = omp_get_nested();
@@ -551,99 +583,84 @@ void fvec_argsort_parallel (size_t n, const float *vals,
551
583
  #pragma omp parallel for num_threads(nseg1)
552
584
  for (int s = 0; s < nseg; s += 2) {
553
585
  if (s + 1 == nseg) { // otherwise isolated segment
554
- memcpy(permB + segs[s].i0, permA + segs[s].i0,
586
+ memcpy(permB + segs[s].i0,
587
+ permA + segs[s].i0,
555
588
  segs[s].len() * sizeof(size_t));
556
589
  } else {
557
590
  int t0 = s * sub_nt / sub_nseg1;
558
591
  int t1 = (s + 1) * sub_nt / sub_nseg1;
559
592
  printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0);
560
- parallel_merge(permA, permB, segs[s], segs[s + 1],
561
- t1 - t0, comp);
593
+ parallel_merge(
594
+ permA, permB, segs[s], segs[s + 1], t1 - t0, comp);
562
595
  }
563
596
  }
564
597
  for (int s = 0; s < nseg; s += 2)
565
598
  segs[s / 2] = segs[s];
566
599
  nseg = nseg1;
567
- std::swap (permA, permB);
600
+ std::swap(permA, permB);
568
601
  }
569
- assert (permA == perm);
602
+ assert(permA == perm);
570
603
  omp_set_nested(prev_nested);
571
- delete [] perm2;
604
+ delete[] perm2;
572
605
  }
573
606
 
574
-
575
-
576
-
577
-
578
-
579
-
580
-
581
-
582
-
583
-
584
-
585
-
586
-
587
-
588
-
589
-
590
-
591
- const float *fvecs_maybe_subsample (
592
- size_t d, size_t *n, size_t nmax, const float *x,
593
- bool verbose, int64_t seed)
594
- {
595
-
596
- if (*n <= nmax) return x; // nothing to do
607
+ const float* fvecs_maybe_subsample(
608
+ size_t d,
609
+ size_t* n,
610
+ size_t nmax,
611
+ const float* x,
612
+ bool verbose,
613
+ int64_t seed) {
614
+ if (*n <= nmax)
615
+ return x; // nothing to do
597
616
 
598
617
  size_t n2 = nmax;
599
618
  if (verbose) {
600
- printf (" Input training set too big (max size is %zd), sampling "
601
- "%zd / %zd vectors\n", nmax, n2, *n);
619
+ printf(" Input training set too big (max size is %zd), sampling "
620
+ "%zd / %zd vectors\n",
621
+ nmax,
622
+ n2,
623
+ *n);
602
624
  }
603
- std::vector<int> subset (*n);
604
- rand_perm (subset.data (), *n, seed);
605
- float *x_subset = new float[n2 * d];
625
+ std::vector<int> subset(*n);
626
+ rand_perm(subset.data(), *n, seed);
627
+ float* x_subset = new float[n2 * d];
606
628
  for (int64_t i = 0; i < n2; i++)
607
- memcpy (&x_subset[i * d],
608
- &x[subset[i] * size_t(d)],
609
- sizeof (x[0]) * d);
629
+ memcpy(&x_subset[i * d], &x[subset[i] * size_t(d)], sizeof(x[0]) * d);
610
630
  *n = n2;
611
631
  return x_subset;
612
632
  }
613
633
 
614
-
615
- void binary_to_real(size_t d, const uint8_t *x_in, float *x_out) {
634
+ void binary_to_real(size_t d, const uint8_t* x_in, float* x_out) {
616
635
  for (size_t i = 0; i < d; ++i) {
617
636
  x_out[i] = 2 * ((x_in[i >> 3] >> (i & 7)) & 1) - 1;
618
637
  }
619
638
  }
620
639
 
621
- void real_to_binary(size_t d, const float *x_in, uint8_t *x_out) {
622
- for (size_t i = 0; i < d / 8; ++i) {
623
- uint8_t b = 0;
624
- for (int j = 0; j < 8; ++j) {
625
- if (x_in[8 * i + j] > 0) {
626
- b |= (1 << j);
627
- }
640
+ void real_to_binary(size_t d, const float* x_in, uint8_t* x_out) {
641
+ for (size_t i = 0; i < d / 8; ++i) {
642
+ uint8_t b = 0;
643
+ for (int j = 0; j < 8; ++j) {
644
+ if (x_in[8 * i + j] > 0) {
645
+ b |= (1 << j);
646
+ }
647
+ }
648
+ x_out[i] = b;
628
649
  }
629
- x_out[i] = b;
630
- }
631
650
  }
632
651
 
633
-
634
652
  // from Python's stringobject.c
635
- uint64_t hash_bytes (const uint8_t *bytes, int64_t n) {
636
- const uint8_t *p = bytes;
653
+ uint64_t hash_bytes(const uint8_t* bytes, int64_t n) {
654
+ const uint8_t* p = bytes;
637
655
  uint64_t x = (uint64_t)(*p) << 7;
638
656
  int64_t len = n;
639
657
  while (--len >= 0) {
640
- x = (1000003*x) ^ *p++;
658
+ x = (1000003 * x) ^ *p++;
641
659
  }
642
660
  x ^= n;
643
661
  return x;
644
662
  }
645
663
 
646
-
647
664
  bool check_openmp() {
648
665
  omp_set_num_threads(10);
649
666
 
@@ -654,7 +671,7 @@ bool check_openmp() {
654
671
  std::vector<int> nt_per_thread(10);
655
672
  size_t sum = 0;
656
673
  bool in_parallel = true;
657
- #pragma omp parallel reduction(+: sum)
674
+ #pragma omp parallel reduction(+ : sum)
658
675
  {
659
676
  if (!omp_in_parallel()) {
660
677
  in_parallel = false;
@@ -665,7 +682,7 @@ bool check_openmp() {
665
682
 
666
683
  nt_per_thread[rank] = nt;
667
684
  #pragma omp for
668
- for(int i = 0; i < 1000 * 1000 * 10; i++) {
685
+ for (int i = 0; i < 1000 * 1000 * 10; i++) {
669
686
  sum += i;
670
687
  }
671
688
  }