faiss 0.2.0 → 0.2.4

Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
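
The remainder of the page is the diff for data/vendor/faiss/faiss/VectorTransform.cpp (file 81 in the list above). Nearly all of it is the clang-format restyling that Faiss adopted upstream between the 1.6 and 1.7 vendored releases; the one behavioral change visible in these hunks is that PCAMatrix gains an epsilon field, initialized to 0 in the constructor and added to each eigenvalue before the eigen_power scaling in prepare_Ab(). As a point of reference for the code below, here is a minimal sketch of driving PCAMatrix directly through the C++ API the diff touches — it is not from this changeset, and the dimensions and data are placeholder assumptions:

    #include <faiss/VectorTransform.h>

    #include <memory>
    #include <vector>

    int main() {
        int n = 1000;                    // number of training vectors (placeholder)
        std::vector<float> data(n * 64); // fill with real 64-d vectors in practice

        // Reduce 64 dims to 16; eigen_power = -0.5 whitens the output,
        // which is the step where the new epsilon term is applied.
        faiss::PCAMatrix pca(64, 16, -0.5, false);
        pca.train(n, data.data());

        // apply() allocates the output buffer; the caller owns it.
        std::unique_ptr<float[]> xt(pca.apply(n, data.data()));
        return 0;
    }

With eigen_power = 0 the transform stays orthonormal (the diff sets is_orthonormal = eigen_power == 0), so reverse_transform() is usable; a negative power whitens the components, and the added epsilon presumably guards that scaling against near-zero eigenvalues.
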
@@ -10,20 +10,19 @@
 #include <faiss/VectorTransform.h>
 
 #include <cinttypes>
-#include <cstdio>
 #include <cmath>
+#include <cstdio>
 #include <cstring>
 #include <memory>
 
+#include <faiss/IndexPQ.h>
+#include <faiss/impl/FaissAssert.h>
 #include <faiss/utils/distances.h>
 #include <faiss/utils/random.h>
 #include <faiss/utils/utils.h>
-#include <faiss/impl/FaissAssert.h>
-#include <faiss/IndexPQ.h>
 
 using namespace faiss;
 
-
 extern "C" {
 
 // this is to keep the clang syntax checker happy
@@ -31,134 +30,183 @@ extern "C" {
 #define FINTEGER int
 #endif
 
-
 /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
 
-int sgemm_ (
-        const char *transa, const char *transb, FINTEGER *m, FINTEGER *
-        n, FINTEGER *k, const float *alpha, const float *a,
-        FINTEGER *lda, const float *b,
-        FINTEGER *ldb, float *beta,
-        float *c, FINTEGER *ldc);
-
-int dgemm_ (
-        const char *transa, const char *transb, FINTEGER *m, FINTEGER *
-        n, FINTEGER *k, const double *alpha, const double *a,
-        FINTEGER *lda, const double *b,
-        FINTEGER *ldb, double *beta,
-        double *c, FINTEGER *ldc);
-
-int ssyrk_ (
-        const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k,
-        float *alpha, float *a, FINTEGER *lda,
-        float *beta, float *c, FINTEGER *ldc);
+int sgemm_(
+        const char* transa,
+        const char* transb,
+        FINTEGER* m,
+        FINTEGER* n,
+        FINTEGER* k,
+        const float* alpha,
+        const float* a,
+        FINTEGER* lda,
+        const float* b,
+        FINTEGER* ldb,
+        float* beta,
+        float* c,
+        FINTEGER* ldc);
+
+int dgemm_(
+        const char* transa,
+        const char* transb,
+        FINTEGER* m,
+        FINTEGER* n,
+        FINTEGER* k,
+        const double* alpha,
+        const double* a,
+        FINTEGER* lda,
+        const double* b,
+        FINTEGER* ldb,
+        double* beta,
+        double* c,
+        FINTEGER* ldc);
+
+int ssyrk_(
+        const char* uplo,
+        const char* trans,
+        FINTEGER* n,
+        FINTEGER* k,
+        float* alpha,
+        float* a,
+        FINTEGER* lda,
+        float* beta,
+        float* c,
+        FINTEGER* ldc);
 
 /* Lapack functions from http://www.netlib.org/clapack/old/single/ */
 
-int ssyev_ (
-        const char *jobz, const char *uplo, FINTEGER *n, float *a,
-        FINTEGER *lda, float *w, float *work, FINTEGER *lwork,
-        FINTEGER *info);
-
-int dsyev_ (
-        const char *jobz, const char *uplo, FINTEGER *n, double *a,
-        FINTEGER *lda, double *w, double *work, FINTEGER *lwork,
-        FINTEGER *info);
+int ssyev_(
+        const char* jobz,
+        const char* uplo,
+        FINTEGER* n,
+        float* a,
+        FINTEGER* lda,
+        float* w,
+        float* work,
+        FINTEGER* lwork,
+        FINTEGER* info);
+
+int dsyev_(
+        const char* jobz,
+        const char* uplo,
+        FINTEGER* n,
+        double* a,
+        FINTEGER* lda,
+        double* w,
+        double* work,
+        FINTEGER* lwork,
+        FINTEGER* info);
 
 int sgesvd_(
-        const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
-        float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt,
-        FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info);
-
+        const char* jobu,
+        const char* jobvt,
+        FINTEGER* m,
+        FINTEGER* n,
+        float* a,
+        FINTEGER* lda,
+        float* s,
+        float* u,
+        FINTEGER* ldu,
+        float* vt,
+        FINTEGER* ldvt,
+        float* work,
+        FINTEGER* lwork,
+        FINTEGER* info);
 
 int dgesvd_(
-        const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
-        double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt,
-        FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info);
-
+        const char* jobu,
+        const char* jobvt,
+        FINTEGER* m,
+        FINTEGER* n,
+        double* a,
+        FINTEGER* lda,
+        double* s,
+        double* u,
+        FINTEGER* ldu,
+        double* vt,
+        FINTEGER* ldvt,
+        double* work,
+        FINTEGER* lwork,
+        FINTEGER* info);
 }
 
 /*********************************************
  * VectorTransform
  *********************************************/
 
-
-
-float * VectorTransform::apply (Index::idx_t n, const float * x) const
-{
-    float * xt = new float[n * d_out];
-    apply_noalloc (n, x, xt);
+float* VectorTransform::apply(Index::idx_t n, const float* x) const {
+    float* xt = new float[n * d_out];
+    apply_noalloc(n, x, xt);
     return xt;
 }
 
-
-void VectorTransform::train (idx_t, const float *) {
+void VectorTransform::train(idx_t, const float*) {
     // does nothing by default
 }
 
-
-void VectorTransform::reverse_transform (
-        idx_t , const float *,
-        float *) const
-{
-    FAISS_THROW_MSG ("reverse transform not implemented");
+void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
    FAISS_THROW_MSG("reverse transform not implemented");
 }
 
-
-
-
 /*********************************************
  * LinearTransform
  *********************************************/
 
 /// both d_in > d_out and d_out < d_in are supported
-LinearTransform::LinearTransform (int d_in, int d_out,
-                                  bool have_bias):
-    VectorTransform (d_in, d_out), have_bias (have_bias),
-    is_orthonormal (false), verbose (false)
-{
+LinearTransform::LinearTransform(int d_in, int d_out, bool have_bias)
+        : VectorTransform(d_in, d_out),
+          have_bias(have_bias),
+          is_orthonormal(false),
+          verbose(false) {
     is_trained = false; // will be trained when A and b are initialized
 }
 
-void LinearTransform::apply_noalloc (Index::idx_t n, const float * x,
-                                     float * xt) const
-{
+void LinearTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
+        const {
     FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
 
     float c_factor;
     if (have_bias) {
-        FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized");
-        float * xi = xt;
+        FAISS_THROW_IF_NOT_MSG(b.size() == d_out, "Bias not initialized");
+        float* xi = xt;
         for (int i = 0; i < n; i++)
-            for(int j = 0; j < d_out; j++)
+            for (int j = 0; j < d_out; j++)
                 *xi++ = b[j];
         c_factor = 1.0;
     } else {
         c_factor = 0.0;
     }
 
-    FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in,
-                            "Transformation matrix not initialized");
+    FAISS_THROW_IF_NOT_MSG(
+            A.size() == d_out * d_in, "Transformation matrix not initialized");
 
     float one = 1;
     FINTEGER nbiti = d_out, ni = n, di = d_in;
-    sgemm_ ("Transposed", "Not transposed",
-            &nbiti, &ni, &di,
-            &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti);
-
+    sgemm_("Transposed",
+           "Not transposed",
+           &nbiti,
+           &ni,
+           &di,
+           &one,
+           A.data(),
+           &di,
+           x,
+           &di,
+           &c_factor,
+           xt,
+           &nbiti);
 }
 
-
-void LinearTransform::transform_transpose (idx_t n, const float * y,
-                                           float *x) const
-{
+void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
+        const {
     if (have_bias) { // allocate buffer to store bias-corrected data
-        float *y_new = new float [n * d_out];
-        const float *yr = y;
-        float *yw = y_new;
+        float* y_new = new float[n * d_out];
+        const float* yr = y;
+        float* yw = y_new;
         for (idx_t i = 0; i < n; i++) {
             for (int j = 0; j < d_out; j++) {
-                *yw++ = *yr++ - b [j];
+                *yw++ = *yr++ - b[j];
             }
         }
         y = y_new;
@@ -167,15 +215,26 @@ void LinearTransform::transform_transpose (idx_t n, const float * y,
     {
         FINTEGER dii = d_in, doi = d_out, ni = n;
         float one = 1.0, zero = 0.0;
-        sgemm_ ("Not", "Not", &dii, &ni, &doi,
-                &one, A.data (), &dii, y, &doi, &zero, x, &dii);
+        sgemm_("Not",
+               "Not",
+               &dii,
+               &ni,
+               &doi,
+               &one,
+               A.data(),
+               &dii,
+               y,
+               &doi,
+               &zero,
+               x,
+               &dii);
     }
 
-    if (have_bias) delete [] y;
+    if (have_bias)
+        delete[] y;
 }
 
-void LinearTransform::set_is_orthonormal ()
-{
+void LinearTransform::set_is_orthonormal() {
     if (d_out > d_in) {
         // not clear what we should do in this case
         is_orthonormal = false;
@@ -193,44 +252,53 @@ void LinearTransform::set_is_orthonormal ()
         FINTEGER dii = d_in, doi = d_out;
         float one = 1.0, zero = 0.0;
 
-        sgemm_ ("Transposed", "Not", &doi, &doi, &dii,
-                &one, A.data (), &dii,
-                A.data(), &dii,
-                &zero, ATA.data(), &doi);
+        sgemm_("Transposed",
+               "Not",
+               &doi,
+               &doi,
+               &dii,
+               &one,
+               A.data(),
+               &dii,
+               A.data(),
+               &dii,
+               &zero,
+               ATA.data(),
+               &doi);
 
         is_orthonormal = true;
         for (long i = 0; i < d_out; i++) {
             for (long j = 0; j < d_out; j++) {
                 float v = ATA[i + j * d_out];
-                if (i == j) v-= 1;
+                if (i == j)
+                    v -= 1;
                 if (fabs(v) > eps) {
                     is_orthonormal = false;
                 }
             }
         }
     }
-
 }
 
-
-void LinearTransform::reverse_transform (idx_t n, const float * xt,
-                                         float *x) const
-{
+void LinearTransform::reverse_transform(idx_t n, const float* xt, float* x)
+        const {
     if (is_orthonormal) {
-        transform_transpose (n, xt, x);
+        transform_transpose(n, xt, x);
     } else {
-        FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices");
+        FAISS_THROW_MSG(
+                "reverse transform not implemented for non-orthonormal matrices");
     }
 }
 
-
-void LinearTransform::print_if_verbose (
-        const char*name, const std::vector<double> &mat,
-        int n, int d) const
-{
-    if (!verbose) return;
+void LinearTransform::print_if_verbose(
+        const char* name,
+        const std::vector<double>& mat,
+        int n,
+        int d) const {
+    if (!verbose)
+        return;
     printf("matrix %s: %d*%d [\n", name, n, d);
-    FAISS_THROW_IF_NOT (mat.size() >= n * d);
+    FAISS_THROW_IF_NOT(mat.size() >= n * d);
     for (int i = 0; i < n; i++) {
         for (int j = 0; j < d; j++) {
             printf("%10.5g ", mat[i * d + j]);
@@ -244,24 +312,22 @@ void LinearTransform::print_if_verbose (
  * RandomRotationMatrix
  *********************************************/
 
-void RandomRotationMatrix::init (int seed)
-{
-
-    if(d_out <= d_in) {
-        A.resize (d_out * d_in);
-        float *q = A.data();
+void RandomRotationMatrix::init(int seed) {
+    if (d_out <= d_in) {
+        A.resize(d_out * d_in);
+        float* q = A.data();
         float_randn(q, d_out * d_in, seed);
         matrix_qr(d_in, d_out, q);
     } else {
         // use tight-frame transformation
-        A.resize (d_out * d_out);
-        float *q = A.data();
+        A.resize(d_out * d_out);
+        float* q = A.data();
         float_randn(q, d_out * d_out, seed);
         matrix_qr(d_out, d_out, q);
         // remove columns
         int i, j;
         for (i = 0; i < d_out; i++) {
-            for(j = 0; j < d_in; j++) {
+            for (j = 0; j < d_in; j++) {
                 q[i * d_in + j] = q[i * d_out + j];
             }
         }
@@ -271,247 +337,281 @@ void RandomRotationMatrix::init (int seed)
     is_trained = true;
 }
 
-void RandomRotationMatrix::train (Index::idx_t /*n*/, const float * /*x*/)
-{
+void RandomRotationMatrix::train(Index::idx_t /*n*/, const float* /*x*/) {
     // initialize with some arbitrary seed
-    init (12345);
+    init(12345);
 }
 
-
 /*********************************************
  * PCAMatrix
  *********************************************/
 
-PCAMatrix::PCAMatrix (int d_in, int d_out,
-                      float eigen_power, bool random_rotation):
-    LinearTransform(d_in, d_out, true),
-    eigen_power(eigen_power), random_rotation(random_rotation)
-{
+PCAMatrix::PCAMatrix(
+        int d_in,
+        int d_out,
+        float eigen_power,
+        bool random_rotation)
+        : LinearTransform(d_in, d_out, true),
+          eigen_power(eigen_power),
+          random_rotation(random_rotation) {
     is_trained = false;
     max_points_per_d = 1000;
     balanced_bins = 0;
+    epsilon = 0;
 }
 
-
 namespace {
 
 /// Compute the eigenvalue decomposition of symmetric matrix cov,
 /// dimensions d_in-by-d_in. Output eigenvectors in cov.
 
-void eig(size_t d_in, double *cov, double *eigenvalues, int verbose)
-{
+void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
     { // compute eigenvalues and vectors
         FINTEGER info = 0, lwork = -1, di = d_in;
         double workq;
 
-        dsyev_ ("Vectors as well", "Upper",
-                &di, cov, &di, eigenvalues, &workq, &lwork, &info);
+        dsyev_("Vectors as well",
+               "Upper",
+               &di,
+               cov,
+               &di,
+               eigenvalues,
+               &workq,
+               &lwork,
+               &info);
         lwork = FINTEGER(workq);
-        double *work = new double[lwork];
+        double* work = new double[lwork];
 
-        dsyev_ ("Vectors as well", "Upper",
-                &di, cov, &di, eigenvalues, work, &lwork, &info);
+        dsyev_("Vectors as well",
+               "Upper",
+               &di,
+               cov,
+               &di,
+               eigenvalues,
+               work,
+               &lwork,
+               &info);
 
-        delete [] work;
+        delete[] work;
 
         if (info != 0) {
-            fprintf (stderr, "WARN ssyev info returns %d, "
-                     "a very bad PCA matrix is learnt\n",
-                     int(info));
+            fprintf(stderr,
+                    "WARN ssyev info returns %d, "
+                    "a very bad PCA matrix is learnt\n",
+                    int(info));
             // do not throw exception, as the matrix could still be useful
         }
 
-
-        if(verbose && d_in <= 10) {
+        if (verbose && d_in <= 10) {
             printf("info=%ld new eigvals=[", long(info));
-            for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]);
+            for (int j = 0; j < d_in; j++)
+                printf("%g ", eigenvalues[j]);
             printf("]\n");
 
-            double *ci = cov;
+            double* ci = cov;
             printf("eigenvecs=\n");
-            for(int i = 0; i < d_in; i++) {
-                for(int j = 0; j < d_in; j++)
+            for (int i = 0; i < d_in; i++) {
+                for (int j = 0; j < d_in; j++)
                     printf("%10.4g ", *ci++);
                 printf("\n");
            }
        }
-
    }
 
    // revert order of eigenvectors & values
 
-    for(int i = 0; i < d_in / 2; i++) {
-
+    for (int i = 0; i < d_in / 2; i++) {
        std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
-        double *v1 = cov + i * d_in;
-        double *v2 = cov + (d_in - 1 - i) * d_in;
-        for(int j = 0; j < d_in; j++)
+        double* v1 = cov + i * d_in;
+        double* v2 = cov + (d_in - 1 - i) * d_in;
+        for (int j = 0; j < d_in; j++)
            std::swap(v1[j], v2[j]);
    }
-
 }
 
+} // namespace
 
-}
-
-void PCAMatrix::train (Index::idx_t n, const float *x)
-{
-    const float * x_in = x;
+void PCAMatrix::train(Index::idx_t n, const float* x) {
+    const float* x_in = x;
 
-    x = fvecs_maybe_subsample (d_in, (size_t*)&n,
-                               max_points_per_d * d_in, x, verbose);
+    x = fvecs_maybe_subsample(
+            d_in, (size_t*)&n, max_points_per_d * d_in, x, verbose);
 
-    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
+    ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
 
     // compute mean
-    mean.clear(); mean.resize(d_in, 0.0);
+    mean.clear();
+    mean.resize(d_in, 0.0);
     if (have_bias) { // we may want to skip the bias
-        const float *xi = x;
+        const float* xi = x;
         for (int i = 0; i < n; i++) {
-            for(int j = 0; j < d_in; j++)
+            for (int j = 0; j < d_in; j++)
                 mean[j] += *xi++;
         }
-        for(int j = 0; j < d_in; j++)
+        for (int j = 0; j < d_in; j++)
             mean[j] /= n;
     }
-    if(verbose) {
+    if (verbose) {
         printf("mean=[");
-        for(int j = 0; j < d_in; j++) printf("%g ", mean[j]);
+        for (int j = 0; j < d_in; j++)
+            printf("%g ", mean[j]);
         printf("]\n");
     }
 
-    if(n >= d_in) {
+    if (n >= d_in) {
         // compute covariance matrix, store it in PCA matrix
         PCAMat.resize(d_in * d_in);
-        float * cov = PCAMat.data();
+        float* cov = PCAMat.data();
         { // initialize with mean * mean^T term
-            float *ci = cov;
-            for(int i = 0; i < d_in; i++) {
-                for(int j = 0; j < d_in; j++)
-                    *ci++ = - n * mean[i] * mean[j];
+            float* ci = cov;
+            for (int i = 0; i < d_in; i++) {
+                for (int j = 0; j < d_in; j++)
+                    *ci++ = -n * mean[i] * mean[j];
             }
         }
         {
             FINTEGER di = d_in, ni = n;
             float one = 1.0;
-            ssyrk_ ("Up", "Non transposed",
-                    &di, &ni, &one, (float*)x, &di, &one, cov, &di);
-
+            ssyrk_("Up",
+                   "Non transposed",
+                   &di,
+                   &ni,
+                   &one,
+                   (float*)x,
+                   &di,
+                   &one,
+                   cov,
+                   &di);
         }
-        if(verbose && d_in <= 10) {
-            float *ci = cov;
+        if (verbose && d_in <= 10) {
+            float* ci = cov;
             printf("cov=\n");
-            for(int i = 0; i < d_in; i++) {
-                for(int j = 0; j < d_in; j++)
+            for (int i = 0; i < d_in; i++) {
+                for (int j = 0; j < d_in; j++)
                     printf("%10g ", *ci++);
                 printf("\n");
             }
         }
 
-        std::vector<double> covd (d_in * d_in);
-        for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i];
+        std::vector<double> covd(d_in * d_in);
+        for (size_t i = 0; i < d_in * d_in; i++)
+            covd[i] = cov[i];
 
-        std::vector<double> eigenvaluesd (d_in);
+        std::vector<double> eigenvaluesd(d_in);
 
-        eig (d_in, covd.data (), eigenvaluesd.data (), verbose);
+        eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
 
-        for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i];
-        eigenvalues.resize (d_in);
+        for (size_t i = 0; i < d_in * d_in; i++)
+            PCAMat[i] = covd[i];
+        eigenvalues.resize(d_in);
 
         for (size_t i = 0; i < d_in; i++)
-            eigenvalues [i] = eigenvaluesd [i];
-
+            eigenvalues[i] = eigenvaluesd[i];
 
     } else {
-
-        std::vector<float> xc (n * d_in);
+        std::vector<float> xc(n * d_in);
 
         for (size_t i = 0; i < n; i++)
-            for(size_t j = 0; j < d_in; j++)
-                xc [i * d_in + j] = x [i * d_in + j] - mean[j];
+            for (size_t j = 0; j < d_in; j++)
+                xc[i * d_in + j] = x[i * d_in + j] - mean[j];
 
         // compute Gram matrix
-        std::vector<float> gram (n * n);
+        std::vector<float> gram(n * n);
         {
             FINTEGER di = d_in, ni = n;
             float one = 1.0, zero = 0.0;
-            ssyrk_ ("Up", "Transposed",
-                    &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni);
+            ssyrk_("Up",
+                   "Transposed",
+                   &ni,
+                   &di,
+                   &one,
+                   xc.data(),
+                   &di,
+                   &zero,
+                   gram.data(),
+                   &ni);
         }
 
-        if(verbose && d_in <= 10) {
-            float *ci = gram.data();
+        if (verbose && d_in <= 10) {
+            float* ci = gram.data();
             printf("gram=\n");
-            for(int i = 0; i < n; i++) {
-                for(int j = 0; j < n; j++)
+            for (int i = 0; i < n; i++) {
+                for (int j = 0; j < n; j++)
                     printf("%10g ", *ci++);
                 printf("\n");
             }
         }
 
-        std::vector<double> gramd (n * n);
+        std::vector<double> gramd(n * n);
         for (size_t i = 0; i < n * n; i++)
-            gramd [i] = gram [i];
+            gramd[i] = gram[i];
 
-        std::vector<double> eigenvaluesd (n);
+        std::vector<double> eigenvaluesd(n);
 
         // eig will fill in only the n first eigenvals
 
-        eig (n, gramd.data (), eigenvaluesd.data (), verbose);
+        eig(n, gramd.data(), eigenvaluesd.data(), verbose);
 
         PCAMat.resize(d_in * n);
 
         for (size_t i = 0; i < n * n; i++)
-            gram [i] = gramd [i];
+            gram[i] = gramd[i];
 
-        eigenvalues.resize (d_in);
+        eigenvalues.resize(d_in);
         // fill in only the n first ones
         for (size_t i = 0; i < n; i++)
-            eigenvalues [i] = eigenvaluesd [i];
+            eigenvalues[i] = eigenvaluesd[i];
 
         { // compute PCAMat = x' * v
             FINTEGER di = d_in, ni = n;
             float one = 1.0;
 
-            sgemm_ ("Non", "Non Trans",
-                    &di, &ni, &ni,
-                    &one, xc.data(), &di, gram.data(), &ni,
-                    &one, PCAMat.data(), &di);
+            sgemm_("Non",
+                   "Non Trans",
+                   &di,
+                   &ni,
+                   &ni,
+                   &one,
+                   xc.data(),
+                   &di,
+                   gram.data(),
+                   &ni,
+                   &one,
+                   PCAMat.data(),
+                   &di);
         }
 
-        if(verbose && d_in <= 10) {
-            float *ci = PCAMat.data();
+        if (verbose && d_in <= 10) {
+            float* ci = PCAMat.data();
             printf("PCAMat=\n");
-            for(int i = 0; i < n; i++) {
-                for(int j = 0; j < d_in; j++)
+            for (int i = 0; i < n; i++) {
+                for (int j = 0; j < d_in; j++)
                     printf("%10g ", *ci++);
                 printf("\n");
             }
         }
-        fvec_renorm_L2 (d_in, n, PCAMat.data());
-
+        fvec_renorm_L2(d_in, n, PCAMat.data());
     }
 
     prepare_Ab();
     is_trained = true;
 }
 
-void PCAMatrix::copy_from (const PCAMatrix & other)
-{
-    FAISS_THROW_IF_NOT (other.is_trained);
+void PCAMatrix::copy_from(const PCAMatrix& other) {
+    FAISS_THROW_IF_NOT(other.is_trained);
     mean = other.mean;
     eigenvalues = other.eigenvalues;
     PCAMat = other.PCAMat;
-    prepare_Ab ();
+    prepare_Ab();
     is_trained = true;
 }
 
-void PCAMatrix::prepare_Ab ()
-{
-    FAISS_THROW_IF_NOT_FMT (
+void PCAMatrix::prepare_Ab() {
+    FAISS_THROW_IF_NOT_FMT(
             d_out * d_in <= PCAMat.size(),
             "PCA matrix cannot output %d dimensions from %d ",
-            d_out, d_in);
+            d_out,
+            d_in);
 
     if (!random_rotation) {
         A = PCAMat;
@@ -519,23 +619,23 @@ void PCAMatrix::prepare_Ab ()
 
         // first scale the components
         if (eigen_power != 0) {
-            float *ai = A.data();
+            float* ai = A.data();
             for (int i = 0; i < d_out; i++) {
-                float factor = pow(eigenvalues[i], eigen_power);
-                for(int j = 0; j < d_in; j++)
+                float factor = pow(eigenvalues[i] + epsilon, eigen_power);
+                for (int j = 0; j < d_in; j++)
                     *ai++ *= factor;
             }
         }
 
         if (balanced_bins != 0) {
-            FAISS_THROW_IF_NOT (d_out % balanced_bins == 0);
+            FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
             int dsub = d_out / balanced_bins;
-            std::vector <float> Ain;
+            std::vector<float> Ain;
             std::swap(A, Ain);
             A.resize(d_out * d_in);
 
-            std::vector <float> accu(balanced_bins);
-            std::vector <int> counter(balanced_bins);
+            std::vector<float> accu(balanced_bins);
+            std::vector<int> counter(balanced_bins);
 
             // greedy assignment
             for (int i = 0; i < d_out; i++) {
@@ -550,9 +650,8 @@ void PCAMatrix::prepare_Ab ()
                 }
                 int row_dst = best_j * dsub + counter[best_j];
                 accu[best_j] += eigenvalues[i];
-                counter[best_j] ++;
-                memcpy (&A[row_dst * d_in], &Ain[i * d_in],
-                        d_in * sizeof (A[0]));
+                counter[best_j]++;
+                memcpy(&A[row_dst * d_in], &Ain[i * d_in], d_in * sizeof(A[0]));
             }
 
             if (verbose) {
@@ -563,11 +662,11 @@ void PCAMatrix::prepare_Ab ()
             }
         }
 
-
     } else {
-        FAISS_THROW_IF_NOT_MSG (balanced_bins == 0,
-             "both balancing bins and applying a random rotation "
-             "does not make sense");
+        FAISS_THROW_IF_NOT_MSG(
+                balanced_bins == 0,
+                "both balancing bins and applying a random rotation "
+                "does not make sense");
         RandomRotationMatrix rr(d_out, d_out);
 
         rr.init(5);
@@ -576,8 +675,8 @@ void PCAMatrix::prepare_Ab ()
         if (eigen_power != 0) {
             for (int i = 0; i < d_out; i++) {
                 float factor = pow(eigenvalues[i], eigen_power);
-                for(int j = 0; j < d_out; j++)
-                   rr.A[j * d_out + i] *= factor;
+                for (int j = 0; j < d_out; j++)
+                    rr.A[j * d_out + i] *= factor;
             }
         }
 
@@ -586,15 +685,24 @@ void PCAMatrix::prepare_Ab ()
             FINTEGER dii = d_in, doo = d_out;
             float one = 1.0, zero = 0.0;
 
-            sgemm_ ("Not", "Not", &dii, &doo, &doo,
-                    &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero,
-                    A.data(), &dii);
-
+            sgemm_("Not",
+                   "Not",
+                   &dii,
+                   &doo,
+                   &doo,
+                   &one,
+                   PCAMat.data(),
+                   &dii,
+                   rr.A.data(),
+                   &doo,
+                   &zero,
+                   A.data(),
+                   &dii);
         }
-
     }
 
-    b.clear(); b.resize(d_out);
+    b.clear();
+    b.resize(d_out);
 
     for (int i = 0; i < d_out; i++) {
         float accu = 0;
@@ -604,57 +712,61 @@ void PCAMatrix::prepare_Ab ()
     }
 
     is_orthonormal = eigen_power == 0;
-
 }
 
 /*********************************************
  * ITQMatrix
  *********************************************/
 
-ITQMatrix::ITQMatrix (int d):
-    LinearTransform(d, d, false),
-    max_iter (50),
-    seed (123)
-{
-}
-
+ITQMatrix::ITQMatrix(int d)
+        : LinearTransform(d, d, false), max_iter(50), seed(123) {}
 
 /** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
-void ITQMatrix::train (Index::idx_t n, const float* xf)
-{
+void ITQMatrix::train(Index::idx_t n, const float* xf) {
     size_t d = d_in;
-    std::vector<double> rotation (d * d);
+    std::vector<double> rotation(d * d);
 
     if (init_rotation.size() == d * d) {
-        memcpy (rotation.data(), init_rotation.data(),
-                d * d * sizeof(rotation[0]));
+        memcpy(rotation.data(),
+               init_rotation.data(),
+               d * d * sizeof(rotation[0]));
     } else {
-        RandomRotationMatrix rrot (d, d);
-        rrot.init (seed);
+        RandomRotationMatrix rrot(d, d);
+        rrot.init(seed);
         for (size_t i = 0; i < d * d; i++) {
             rotation[i] = rrot.A[i];
         }
     }
 
-    std::vector<double> x (n * d);
+    std::vector<double> x(n * d);
 
     for (size_t i = 0; i < n * d; i++) {
         x[i] = xf[i];
     }
 
-    std::vector<double> rotated_x (n * d), cov_mat (d * d);
-    std::vector<double> u (d * d), vt (d * d), singvals (d);
+    std::vector<double> rotated_x(n * d), cov_mat(d * d);
+    std::vector<double> u(d * d), vt(d * d), singvals(d);
 
     for (int i = 0; i < max_iter; i++) {
-        print_if_verbose ("rotation", rotation, d, d);
+        print_if_verbose("rotation", rotation, d, d);
         { // rotated_data = np.dot(training_data, rotation)
             FINTEGER di = d, ni = n;
             double one = 1, zero = 0;
-            dgemm_ ("N", "N", &di, &ni, &di,
-                    &one, rotation.data(), &di, x.data(), &di,
-                    &zero, rotated_x.data(), &di);
+            dgemm_("N",
+                   "N",
+                   &di,
+                   &ni,
+                   &di,
+                   &one,
+                   rotation.data(),
+                   &di,
+                   x.data(),
+                   &di,
+                   &zero,
+                   rotated_x.data(),
+                   &di);
         }
-        print_if_verbose ("rotated_x", rotated_x, n, d);
+        print_if_verbose("rotated_x", rotated_x, n, d);
         // binarize
         for (size_t j = 0; j < n * d; j++) {
             rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
@@ -663,88 +775,119 @@ void ITQMatrix::train (Index::idx_t n, const float* xf)
         }
         { // rotated_data = np.dot(training_data, rotation)
             FINTEGER di = d, ni = n;
             double one = 1, zero = 0;
-            dgemm_ ("N", "T", &di, &di, &ni,
-                    &one, rotated_x.data(), &di, x.data(), &di,
-                    &zero, cov_mat.data(), &di);
+            dgemm_("N",
+                   "T",
+                   &di,
+                   &di,
+                   &ni,
+                   &one,
+                   rotated_x.data(),
+                   &di,
+                   x.data(),
+                   &di,
+                   &zero,
+                   cov_mat.data(),
+                   &di);
         }
-        print_if_verbose ("cov_mat", cov_mat, d, d);
+        print_if_verbose("cov_mat", cov_mat, d, d);
         // SVD
         {
-
             FINTEGER di = d;
             FINTEGER lwork = -1, info;
             double lwork1;
 
             // workspace query
-            dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
-                     singvals.data(), u.data(), &di,
-                     vt.data(), &di,
-                     &lwork1, &lwork, &info);
-
-            FAISS_THROW_IF_NOT (info == 0);
-            lwork = size_t (lwork1);
-            std::vector<double> work (lwork);
-            dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
-                     singvals.data(), u.data(), &di,
-                     vt.data(), &di,
-                     work.data(), &lwork, &info);
-            FAISS_THROW_IF_NOT_FMT (info == 0, "sgesvd returned info=%d", info);
-
+            dgesvd_("A",
+                    "A",
+                    &di,
+                    &di,
+                    cov_mat.data(),
+                    &di,
+                    singvals.data(),
+                    u.data(),
+                    &di,
+                    vt.data(),
+                    &di,
+                    &lwork1,
+                    &lwork,
+                    &info);
+
+            FAISS_THROW_IF_NOT(info == 0);
+            lwork = size_t(lwork1);
+            std::vector<double> work(lwork);
+            dgesvd_("A",
+                    "A",
+                    &di,
+                    &di,
+                    cov_mat.data(),
+                    &di,
+                    singvals.data(),
+                    u.data(),
+                    &di,
+                    vt.data(),
+                    &di,
+                    work.data(),
+                    &lwork,
+                    &info);
+            FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
         }
-        print_if_verbose ("u", u, d, d);
-        print_if_verbose ("vt", vt, d, d);
+        print_if_verbose("u", u, d, d);
+        print_if_verbose("vt", vt, d, d);
         // update rotation
         {
             FINTEGER di = d;
             double one = 1, zero = 0;
-            dgemm_ ("N", "T", &di, &di, &di,
-                    &one, u.data(), &di, vt.data(), &di,
-                    &zero, rotation.data(), &di);
+            dgemm_("N",
+                   "T",
+                   &di,
+                   &di,
+                   &di,
+                   &one,
+                   u.data(),
+                   &di,
+                   vt.data(),
+                   &di,
+                   &zero,
+                   rotation.data(),
+                   &di);
         }
-        print_if_verbose ("final rot", rotation, d, d);
-
+        print_if_verbose("final rot", rotation, d, d);
     }
-    A.resize (d * d);
+    A.resize(d * d);
     for (size_t i = 0; i < d; i++) {
         for (size_t j = 0; j < d; j++) {
             A[i + d * j] = rotation[j + d * i];
         }
     }
     is_trained = true;
-
 }
 
-ITQTransform::ITQTransform (int d_in, int d_out, bool do_pca):
-    VectorTransform (d_in, d_out),
-    do_pca (do_pca),
-    itq (d_out),
-    pca_then_itq (d_in, d_out, false)
-{
+ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca)
+        : VectorTransform(d_in, d_out),
+          do_pca(do_pca),
+          itq(d_out),
+          pca_then_itq(d_in, d_out, false) {
     if (!do_pca) {
-        FAISS_THROW_IF_NOT (d_in == d_out);
+        FAISS_THROW_IF_NOT(d_in == d_out);
     }
     max_train_per_dim = 10;
     is_trained = false;
 }
 
+void ITQTransform::train(idx_t n, const float* x) {
+    FAISS_THROW_IF_NOT(!is_trained);
 
-
-
-void ITQTransform::train (idx_t n, const float *x)
-{
-    FAISS_THROW_IF_NOT (!is_trained);
-
-    const float * x_in = x;
+    const float* x_in = x;
     size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
-    x = fvecs_maybe_subsample (d_in, (size_t*)&n, max_train_points, x);
+    x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x);
 
-    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
+    ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
 
-    std::unique_ptr<float []> x_norm(new float[n * d_in]);
+    std::unique_ptr<float[]> x_norm(new float[n * d_in]);
     { // normalize
         int d = d_in;
 
-        mean.resize (d, 0);
+        mean.resize(d, 0);
         for (idx_t i = 0; i < n; i++) {
             for (idx_t j = 0; j < d; j++) {
                 mean[j] += x[i * d + j];
@@ -755,38 +898,47 @@ void ITQTransform::train (idx_t n, const float *x)
            }
        }
        for (idx_t i = 0; i < n; i++) {
            for (idx_t j = 0; j < d; j++) {
-               x_norm[i * d + j] = x[i * d + j] - mean[j];
+                x_norm[i * d + j] = x[i * d + j] - mean[j];
            }
        }
-        fvec_renorm_L2 (d_in, n, x_norm.get());
+        fvec_renorm_L2(d_in, n, x_norm.get());
    }
 
    // train PCA
 
-    PCAMatrix pca (d_in, d_out);
-    float *x_pca;
-    std::unique_ptr<float []> x_pca_del;
+    PCAMatrix pca(d_in, d_out);
+    float* x_pca;
+    std::unique_ptr<float[]> x_pca_del;
    if (do_pca) {
-        pca.have_bias = false;  // for consistency with reference implem
-        pca.train (n, x_norm.get());
-        x_pca = pca.apply (n, x_norm.get());
+        pca.have_bias = false; // for consistency with reference implem
+        pca.train(n, x_norm.get());
+        x_pca = pca.apply(n, x_norm.get());
        x_pca_del.reset(x_pca);
    } else {
        x_pca = x_norm.get();
    }
 
    // train ITQ
-    itq.train (n, x_pca);
+    itq.train(n, x_pca);
 
    // merge PCA and ITQ
    if (do_pca) {
        FINTEGER di = d_out, dini = d_in;
        float one = 1, zero = 0;
        pca_then_itq.A.resize(d_in * d_out);
-        sgemm_ ("N", "N", &dini, &di, &di,
-                &one, pca.A.data(), &dini,
-                itq.A.data(), &di,
-                &zero, pca_then_itq.A.data(), &dini);
+        sgemm_("N",
+               "N",
+               &dini,
+               &di,
+               &di,
+               &one,
+               pca.A.data(),
+               &dini,
+               itq.A.data(),
+               &di,
+               &zero,
+               pca_then_itq.A.data(),
+               &dini);
    } else {
        pca_then_itq.A = itq.A;
    }
@@ -794,12 +946,11 @@ void ITQTransform::train (idx_t n, const float *x)
    is_trained = true;
 }
 
-void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
-                                  float * xt) const
-{
+void ITQTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
+        const {
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
 
-    std::unique_ptr<float []> x_norm(new float[n * d_in]);
+    std::unique_ptr<float[]> x_norm(new float[n * d_in]);
    { // normalize
        int d = d_in;
        for (idx_t i = 0; i < n; i++) {
@@ -809,41 +960,36 @@ void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
        }
        // this is not really useful if we are going to binarize right
        // afterwards but OK
-        fvec_renorm_L2 (d_in, n, x_norm.get());
+        fvec_renorm_L2(d_in, n, x_norm.get());
    }
 
-    pca_then_itq.apply_noalloc (n, x_norm.get(), xt);
+    pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
 }
 
 /*********************************************
  * OPQMatrix
  *********************************************/
 
-
-OPQMatrix::OPQMatrix (int d, int M, int d2):
-    LinearTransform (d, d2 == -1 ? d : d2, false), M(M),
-    niter (50),
-    niter_pq (4), niter_pq_0 (40),
-    verbose(false),
-    pq(nullptr)
-{
+OPQMatrix::OPQMatrix(int d, int M, int d2)
+        : LinearTransform(d, d2 == -1 ? d : d2, false),
+          M(M),
+          niter(50),
+          niter_pq(4),
+          niter_pq_0(40),
+          verbose(false),
+          pq(nullptr) {
    is_trained = false;
    // OPQ is quite expensive to train, so set this right.
    max_train_points = 256 * 256;
    pq = nullptr;
 }
 
+void OPQMatrix::train(Index::idx_t n, const float* x) {
+    const float* x_in = x;
 
+    x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose);
 
-void OPQMatrix::train (Index::idx_t n, const float *x)
-{
-
-    const float * x_in = x;
-
-    x = fvecs_maybe_subsample (d_in, (size_t*)&n,
-                               max_train_points, x, verbose);
-
-    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
+    ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
 
    // To support d_out > d_in, we pad input vectors with 0s to d_out
    size_t d = d_out <= d_in ? d_in : d_out;
@@ -867,22 +1013,26 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
 #endif
 
    if (verbose) {
-        printf ("OPQMatrix::train: training an OPQ rotation matrix "
-                "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
-                M, n, d_in, d_out);
+        printf("OPQMatrix::train: training an OPQ rotation matrix "
+               "for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
+               M,
+               n,
+               d_in,
+               d_out);
    }
 
-    std::vector<float> xtrain (n * d);
+    std::vector<float> xtrain(n * d);
    // center x
    {
-        std::vector<float> sum (d);
-        const float *xi = x;
+        std::vector<float> sum(d);
+        const float* xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
-                sum [j] += *xi++;
+                sum[j] += *xi++;
        }
-        for (int i = 0; i < d; i++) sum[i] /= n;
-        float *yi = xtrain.data();
+        for (int i = 0; i < d; i++)
+            sum[i] /= n;
+        float* yi = xtrain.data();
        xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
@@ -890,71 +1040,80 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
            yi += d - d_in;
        }
    }
-    float *rotation;
+    float* rotation;
 
-    if (A.size () == 0) {
-        A.resize (d * d);
+    if (A.size() == 0) {
+        A.resize(d * d);
        rotation = A.data();
        if (verbose)
            printf("  OPQMatrix::train: making random %zd*%zd rotation\n",
-                   d, d);
-        float_randn (rotation, d * d, 1234);
-        matrix_qr (d, d, rotation);
+                   d,
+                   d);
+        float_randn(rotation, d * d, 1234);
+        matrix_qr(d, d, rotation);
        // we use only the d * d2 upper part of the matrix
-        A.resize (d * d2);
+        A.resize(d * d2);
    } else {
-        FAISS_THROW_IF_NOT (A.size() == d * d2);
+        FAISS_THROW_IF_NOT(A.size() == d * d2);
        rotation = A.data();
    }
 
-    std::vector<float>
-        xproj (d2 * n), pq_recons (d2 * n), xxr (d * n),
-        tmp(d * d * 4);
-
+    std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * n),
+            tmp(d * d * 4);
 
-    ProductQuantizer pq_default (d2, M, 8);
-    ProductQuantizer &pq_regular = pq ? *pq : pq_default;
-    std::vector<uint8_t> codes (pq_regular.code_size * n);
+    ProductQuantizer pq_default(d2, M, 8);
+    ProductQuantizer& pq_regular = pq ? *pq : pq_default;
+    std::vector<uint8_t> codes(pq_regular.code_size * n);
 
    double t0 = getmillisecs();
    for (int iter = 0; iter < niter; iter++) {
-
        { // torch.mm(xtrain, rotation:t())
            FINTEGER di = d, d2i = d2, ni = n;
            float zero = 0, one = 1;
-            sgemm_ ("Transposed", "Not transposed",
-                    &d2i, &ni, &di,
-                    &one, rotation, &di,
-                    xtrain.data(), &di,
-                    &zero, xproj.data(), &d2i);
+            sgemm_("Transposed",
+                   "Not transposed",
+                   &d2i,
+                   &ni,
+                   &di,
+                   &one,
+                   rotation,
+                   &di,
+                   xtrain.data(),
+                   &di,
+                   &zero,
+                   xproj.data(),
+                   &d2i);
        }
 
        pq_regular.cp.max_points_per_centroid = 1000;
        pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
        pq_regular.verbose = verbose;
-        pq_regular.train (n, xproj.data());
+        pq_regular.train(n, xproj.data());
 
        if (verbose) {
            printf("    encode / decode\n");
        }
        if (pq_regular.assign_index) {
-            pq_regular.compute_codes_with_assign_index
-                (xproj.data(), codes.data(), n);
+            pq_regular.compute_codes_with_assign_index(
+                    xproj.data(), codes.data(), n);
        } else {
-            pq_regular.compute_codes (xproj.data(), codes.data(), n);
+            pq_regular.compute_codes(xproj.data(), codes.data(), n);
        }
-        pq_regular.decode (codes.data(), pq_recons.data(), n);
+        pq_regular.decode(codes.data(), pq_recons.data(), n);
 
-        float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n;
+        float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;
 
        if (verbose)
-            printf ("    Iteration %d (%d PQ iterations):"
-                    "%.3f s, obj=%g\n", iter, pq_regular.cp.niter,
-                    (getmillisecs () - t0) / 1000.0, pq_err);
+            printf("    Iteration %d (%d PQ iterations):"
+                   "%.3f s, obj=%g\n",
+                   iter,
+                   pq_regular.cp.niter,
+                   (getmillisecs() - t0) / 1000.0,
+                   pq_err);
 
        {
-            float *u = tmp.data(), *vt = &tmp [d * d];
-            float *sing_val = &tmp [2 * d * d];
+            float *u = tmp.data(), *vt = &tmp[d * d];
+            float* sing_val = &tmp[2 * d * d];
            FINTEGER di = d, d2i = d2, ni = n;
            float one = 1, zero = 0;
 
@@ -962,36 +1121,69 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
                printf("    X * recons\n");
            }
            // torch.mm(xtrain:t(), pq_recons)
-            sgemm_ ("Not", "Transposed",
-                    &d2i, &di, &ni,
-                    &one, pq_recons.data(), &d2i,
-                    xtrain.data(), &di,
-                    &zero, xxr.data(), &d2i);
-
+            sgemm_("Not",
+                   "Transposed",
+                   &d2i,
+                   &di,
+                   &ni,
+                   &one,
+                   pq_recons.data(),
+                   &d2i,
+                   xtrain.data(),
+                   &di,
+                   &zero,
+                   xxr.data(),
+                   &d2i);
 
            FINTEGER lwork = -1, info = -1;
            float worksz;
            // workspace query
-            sgesvd_ ("All", "All",
-                     &d2i, &di, xxr.data(), &d2i,
-                     sing_val,
-                     vt, &d2i, u, &di,
-                     &worksz, &lwork, &info);
+            sgesvd_("All",
+                    "All",
+                    &d2i,
+                    &di,
+                    xxr.data(),
+                    &d2i,
+                    sing_val,
+                    vt,
+                    &d2i,
+                    u,
+                    &di,
+                    &worksz,
+                    &lwork,
+                    &info);
 
            lwork = int(worksz);
-            std::vector<float> work (lwork);
+            std::vector<float> work(lwork);
            // u and vt swapped
-            sgesvd_ ("All", "All",
-                     &d2i, &di, xxr.data(), &d2i,
-                     sing_val,
-                     vt, &d2i, u, &di,
-                     work.data(), &lwork, &info);
-
-            sgemm_ ("Transposed", "Transposed",
-                    &di, &d2i, &d2i,
-                    &one, u, &di, vt, &d2i,
-                    &zero, rotation, &di);
-
+            sgesvd_("All",
+                    "All",
+                    &d2i,
+                    &di,
+                    xxr.data(),
+                    &d2i,
+                    sing_val,
+                    vt,
+                    &d2i,
+                    u,
+                    &di,
+                    work.data(),
+                    &lwork,
+                    &info);
+
+            sgemm_("Transposed",
+                   "Transposed",
+                   &di,
+                   &d2i,
+                   &d2i,
+                   &one,
+                   u,
+                   &di,
+                   vt,
+                   &d2i,
+                   &zero,
+                   rotation,
+                   &di);
        }
        pq_regular.train_type = ProductQuantizer::Train_hot_start;
    }
@@ -999,59 +1191,52 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
    // revert A matrix
    if (d > d_in) {
        for (long i = 0; i < d_out; i++)
-            memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
-        A.resize (d_in * d_out);
+            memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
+        A.resize(d_in * d_out);
    }
 
    is_trained = true;
    is_orthonormal = true;
 }
 
-
 /*********************************************
  * NormalizationTransform
  *********************************************/
 
-NormalizationTransform::NormalizationTransform (int d, float norm):
-    VectorTransform (d, d), norm (norm)
-{
-}
+NormalizationTransform::NormalizationTransform(int d, float norm)
+        : VectorTransform(d, d), norm(norm) {}
 
-NormalizationTransform::NormalizationTransform ():
-    VectorTransform (-1, -1), norm (-1)
-{
-}
+NormalizationTransform::NormalizationTransform()
+        : VectorTransform(-1, -1), norm(-1) {}
 
-void NormalizationTransform::apply_noalloc
-      (idx_t n, const float* x, float* xt) const
-{
+void NormalizationTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+        const {
    if (norm == 2.0) {
-        memcpy (xt, x, sizeof (x[0]) * n * d_in);
-        fvec_renorm_L2 (d_in, n, xt);
+        memcpy(xt, x, sizeof(x[0]) * n * d_in);
+        fvec_renorm_L2(d_in, n, xt);
    } else {
-        FAISS_THROW_MSG ("not implemented");
+        FAISS_THROW_MSG("not implemented");
    }
 }
 
-void NormalizationTransform::reverse_transform (idx_t n, const float* xt,
-                                                float* x) const
-{
-    memcpy (x, xt, sizeof (xt[0]) * n * d_in);
+void NormalizationTransform::reverse_transform(
+        idx_t n,
+        const float* xt,
+        float* x) const {
+    memcpy(x, xt, sizeof(xt[0]) * n * d_in);
 }
 
 /*********************************************
  * CenteringTransform
  *********************************************/
 
-CenteringTransform::CenteringTransform (int d):
-    VectorTransform (d, d)
-{
+CenteringTransform::CenteringTransform(int d) : VectorTransform(d, d) {
    is_trained = false;
 }
 
-void CenteringTransform::train(Index::idx_t n, const float *x) {
+void CenteringTransform::train(Index::idx_t n, const float* x) {
    FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
-    mean.resize (d_in, 0);
+    mean.resize(d_in, 0);
    for (idx_t i = 0; i < n; i++) {
        for (size_t j = 0; j < d_in; j++) {
            mean[j] += *x++;
@@ -1064,11 +1249,9 @@ void CenteringTransform::train(Index::idx_t n, const float *x) {
    is_trained = true;
 }
 
-
-void CenteringTransform::apply_noalloc
-    (idx_t n, const float* x, float* xt) const
-{
-    FAISS_THROW_IF_NOT (is_trained);
+void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+        const {
+    FAISS_THROW_IF_NOT(is_trained);
 
    for (idx_t i = 0; i < n; i++) {
        for (size_t j = 0; j < d_in; j++) {
@@ -1077,64 +1260,58 @@ void CenteringTransform::apply_noalloc
        }
    }
 
-void CenteringTransform::reverse_transform (idx_t n, const float* xt,
-                                            float* x) const
-{
-    FAISS_THROW_IF_NOT (is_trained);
+void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
+        const {
+    FAISS_THROW_IF_NOT(is_trained);
 
    for (idx_t i = 0; i < n; i++) {
        for (size_t j = 0; j < d_in; j++) {
            *x++ = *xt++ + mean[j];
        }
    }
-
 }
 
-
-
-
-
 /*********************************************
  * RemapDimensionsTransform
  *********************************************/
 
-
-RemapDimensionsTransform::RemapDimensionsTransform (
-        int d_in, int d_out, const int *map_in):
-    VectorTransform (d_in, d_out)
-{
-    map.resize (d_out);
+RemapDimensionsTransform::RemapDimensionsTransform(
+        int d_in,
+        int d_out,
+        const int* map_in)
+        : VectorTransform(d_in, d_out) {
+    map.resize(d_out);
    for (int i = 0; i < d_out; i++) {
        map[i] = map_in[i];
-        FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
+        FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
    }
 }
 
-RemapDimensionsTransform::RemapDimensionsTransform (
-        int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out)
-{
-    map.resize (d_out, -1);
+RemapDimensionsTransform::RemapDimensionsTransform(
+        int d_in,
+        int d_out,
+        bool uniform)
+        : VectorTransform(d_in, d_out) {
+    map.resize(d_out, -1);
 
    if (uniform) {
        if (d_in < d_out) {
            for (int i = 0; i < d_in; i++) {
-                map [i * d_out / d_in] = i;
-            }
+                map[i * d_out / d_in] = i;
+            }
        } else {
            for (int i = 0; i < d_out; i++) {
-                map [i] = i * d_in / d_out;
+                map[i] = i * d_in / d_out;
            }
        }
    } else {
        for (int i = 0; i < d_in && i < d_out; i++)
-            map [i] = i;
+            map[i] = i;
    }
 }
 
-
-void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
-                                              float *xt) const
-{
+void RemapDimensionsTransform::apply_noalloc(idx_t n, const float* x, float* xt)
+        const {
    for (idx_t i = 0; i < n; i++) {
        for (int j = 0; j < d_out; j++) {
            xt[j] = map[j] < 0 ? 0 : x[map[j]];
@@ -1144,13 +1321,15 @@ void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
        }
    }
 
-void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt,
-                                                  float *x) const
-{
-    memset (x, 0, sizeof (*x) * n * d_in);
+void RemapDimensionsTransform::reverse_transform(
+        idx_t n,
+        const float* xt,
+        float* x) const {
+    memset(x, 0, sizeof(*x) * n * d_in);
    for (idx_t i = 0; i < n; i++) {
        for (int j = 0; j < d_out; j++) {
-            if (map[j] >= 0) x[map[j]] = xt[j];
+            if (map[j] >= 0)
+                x[map[j]] = xt[j];
        }
        x += d_in;
        xt += d_out;