faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -10,10 +10,10 @@
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
12
  #include <algorithm>
13
- #include <cstdio>
14
13
  #include <cassert>
15
- #include <cstring>
16
14
  #include <cmath>
15
+ #include <cstdio>
16
+ #include <cstring>
17
17
 
18
18
  #include <omp.h>
19
19
 
@@ -21,186 +21,153 @@
21
21
  #include <faiss/impl/FaissAssert.h>
22
22
  #include <faiss/impl/ResultHandler.h>
23
23
 
24
-
25
-
26
24
  #ifndef FINTEGER
27
25
  #define FINTEGER long
28
26
  #endif
29
27
 
30
-
31
28
  extern "C" {
32
29
 
33
30
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
34
31
 
35
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
36
- n, FINTEGER *k, const float *alpha, const float *a,
37
- FINTEGER *lda, const float *b, FINTEGER *
38
- ldb, float *beta, float *c, FINTEGER *ldc);
39
-
40
-
32
+ int sgemm_(
33
+ const char* transa,
34
+ const char* transb,
35
+ FINTEGER* m,
36
+ FINTEGER* n,
37
+ FINTEGER* k,
38
+ const float* alpha,
39
+ const float* a,
40
+ FINTEGER* lda,
41
+ const float* b,
42
+ FINTEGER* ldb,
43
+ float* beta,
44
+ float* c,
45
+ FINTEGER* ldc);
41
46
  }
42
47
 
43
-
44
48
  namespace faiss {
45
49
 
46
-
47
-
48
50
  /***************************************************************************
49
51
  * Matrix/vector ops
50
52
  ***************************************************************************/
51
53
 
52
-
53
-
54
-
55
54
  /* Compute the L2 norm of a set of nx vectors */
56
- void fvec_norms_L2 (float * __restrict nr,
57
- const float * __restrict x,
58
- size_t d, size_t nx)
59
- {
60
-
55
+ void fvec_norms_L2(
56
+ float* __restrict nr,
57
+ const float* __restrict x,
58
+ size_t d,
59
+ size_t nx) {
61
60
  #pragma omp parallel for
62
61
  for (int64_t i = 0; i < nx; i++) {
63
- nr[i] = sqrtf (fvec_norm_L2sqr (x + i * d, d));
62
+ nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
64
63
  }
65
64
  }
66
65
 
67
- void fvec_norms_L2sqr (float * __restrict nr,
68
- const float * __restrict x,
69
- size_t d, size_t nx)
70
- {
66
+ void fvec_norms_L2sqr(
67
+ float* __restrict nr,
68
+ const float* __restrict x,
69
+ size_t d,
70
+ size_t nx) {
71
71
  #pragma omp parallel for
72
72
  for (int64_t i = 0; i < nx; i++)
73
- nr[i] = fvec_norm_L2sqr (x + i * d, d);
73
+ nr[i] = fvec_norm_L2sqr(x + i * d, d);
74
74
  }
75
75
 
76
-
77
-
78
- void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
79
- {
76
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
80
77
  #pragma omp parallel for
81
78
  for (int64_t i = 0; i < nx; i++) {
82
- float * __restrict xi = x + i * d;
79
+ float* __restrict xi = x + i * d;
83
80
 
84
- float nr = fvec_norm_L2sqr (xi, d);
81
+ float nr = fvec_norm_L2sqr(xi, d);
85
82
 
86
83
  if (nr > 0) {
87
84
  size_t j;
88
- const float inv_nr = 1.0 / sqrtf (nr);
85
+ const float inv_nr = 1.0 / sqrtf(nr);
89
86
  for (j = 0; j < d; j++)
90
87
  xi[j] *= inv_nr;
91
88
  }
92
89
  }
93
90
  }
94
91
 
95
-
96
-
97
-
98
-
99
-
100
-
101
-
102
-
103
-
104
-
105
-
106
92
  /***************************************************************************
107
93
  * KNN functions
108
94
  ***************************************************************************/
109
95
 
110
96
  namespace {
111
97
 
112
-
113
-
114
98
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
115
- template<class ResultHandler>
116
- void exhaustive_inner_product_seq (
117
- const float * x,
118
- const float * y,
119
- size_t d, size_t nx, size_t ny,
120
- ResultHandler &res)
121
- {
122
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
123
-
124
- check_period *= omp_get_max_threads();
125
-
99
+ template <class ResultHandler>
100
+ void exhaustive_inner_product_seq(
101
+ const float* x,
102
+ const float* y,
103
+ size_t d,
104
+ size_t nx,
105
+ size_t ny,
106
+ ResultHandler& res) {
126
107
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
108
+ int nt = std::min(int(nx), omp_get_max_threads());
127
109
 
128
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
129
- size_t i1 = std::min(i0 + check_period, nx);
130
-
131
- #pragma omp parallel
132
- {
133
- SingleResultHandler resi(res);
110
+ #pragma omp parallel num_threads(nt)
111
+ {
112
+ SingleResultHandler resi(res);
134
113
  #pragma omp for
135
- for (int64_t i = i0; i < i1; i++) {
136
- const float * x_i = x + i * d;
137
- const float * y_j = y;
114
+ for (int64_t i = 0; i < nx; i++) {
115
+ const float* x_i = x + i * d;
116
+ const float* y_j = y;
138
117
 
139
- resi.begin(i);
118
+ resi.begin(i);
140
119
 
141
- for (size_t j = 0; j < ny; j++) {
142
- float ip = fvec_inner_product (x_i, y_j, d);
143
- resi.add_result(ip, j);
144
- y_j += d;
145
- }
146
- resi.end();
120
+ for (size_t j = 0; j < ny; j++) {
121
+ float ip = fvec_inner_product(x_i, y_j, d);
122
+ resi.add_result(ip, j);
123
+ y_j += d;
147
124
  }
125
+ resi.end();
148
126
  }
149
- InterruptCallback::check ();
150
127
  }
151
-
152
128
  }
153
129
 
154
- template<class ResultHandler>
155
- void exhaustive_L2sqr_seq (
156
- const float * x,
157
- const float * y,
158
- size_t d, size_t nx, size_t ny,
159
- ResultHandler & res)
160
- {
161
-
162
- size_t check_period = InterruptCallback::get_period_hint (ny * d);
163
- check_period *= omp_get_max_threads();
130
+ template <class ResultHandler>
131
+ void exhaustive_L2sqr_seq(
132
+ const float* x,
133
+ const float* y,
134
+ size_t d,
135
+ size_t nx,
136
+ size_t ny,
137
+ ResultHandler& res) {
164
138
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
139
+ int nt = std::min(int(nx), omp_get_max_threads());
165
140
 
166
- for (size_t i0 = 0; i0 < nx; i0 += check_period) {
167
- size_t i1 = std::min(i0 + check_period, nx);
168
-
169
- #pragma omp parallel
170
- {
171
- SingleResultHandler resi(res);
141
+ #pragma omp parallel num_threads(nt)
142
+ {
143
+ SingleResultHandler resi(res);
172
144
  #pragma omp for
173
- for (int64_t i = i0; i < i1; i++) {
174
- const float * x_i = x + i * d;
175
- const float * y_j = y;
176
- resi.begin(i);
177
- for (size_t j = 0; j < ny; j++) {
178
- float disij = fvec_L2sqr (x_i, y_j, d);
179
- resi.add_result(disij, j);
180
- y_j += d;
181
- }
182
- resi.end();
145
+ for (int64_t i = 0; i < nx; i++) {
146
+ const float* x_i = x + i * d;
147
+ const float* y_j = y;
148
+ resi.begin(i);
149
+ for (size_t j = 0; j < ny; j++) {
150
+ float disij = fvec_L2sqr(x_i, y_j, d);
151
+ resi.add_result(disij, j);
152
+ y_j += d;
183
153
  }
154
+ resi.end();
184
155
  }
185
- InterruptCallback::check ();
186
156
  }
187
-
188
- };
189
-
190
-
191
-
192
-
157
+ }
193
158
 
194
159
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
195
- template<class ResultHandler>
196
- void exhaustive_inner_product_blas (
197
- const float * x,
198
- const float * y,
199
- size_t d, size_t nx, size_t ny,
200
- ResultHandler & res)
201
- {
160
+ template <class ResultHandler>
161
+ void exhaustive_inner_product_blas(
162
+ const float* x,
163
+ const float* y,
164
+ size_t d,
165
+ size_t nx,
166
+ size_t ny,
167
+ ResultHandler& res) {
202
168
  // BLAS does not like empty matrices
203
- if (nx == 0 || ny == 0) return;
169
+ if (nx == 0 || ny == 0)
170
+ return;
204
171
 
205
172
  /* block sizes */
206
173
  const size_t bs_x = distance_compute_blas_query_bs;
@@ -209,86 +176,105 @@ void exhaustive_inner_product_blas (
209
176
 
210
177
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
211
178
  size_t i1 = i0 + bs_x;
212
- if(i1 > nx) i1 = nx;
179
+ if (i1 > nx)
180
+ i1 = nx;
213
181
 
214
182
  res.begin_multiple(i0, i1);
215
183
 
216
184
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
217
185
  size_t j1 = j0 + bs_y;
218
- if (j1 > ny) j1 = ny;
186
+ if (j1 > ny)
187
+ j1 = ny;
219
188
  /* compute the actual dot products */
220
189
  {
221
190
  float one = 1, zero = 0;
222
191
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
223
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
224
- y + j0 * d, &di,
225
- x + i0 * d, &di, &zero,
226
- ip_block.get(), &nyi);
192
+ sgemm_("Transpose",
193
+ "Not transpose",
194
+ &nyi,
195
+ &nxi,
196
+ &di,
197
+ &one,
198
+ y + j0 * d,
199
+ &di,
200
+ x + i0 * d,
201
+ &di,
202
+ &zero,
203
+ ip_block.get(),
204
+ &nyi);
227
205
  }
228
206
 
229
207
  res.add_results(j0, j1, ip_block.get());
230
-
231
208
  }
232
209
  res.end_multiple();
233
- InterruptCallback::check ();
234
-
210
+ InterruptCallback::check();
235
211
  }
236
212
  }
237
213
 
238
-
239
-
240
-
241
214
  // distance correction is an operator that can be applied to transform
242
215
  // the distances
243
- template<class ResultHandler>
244
- void exhaustive_L2sqr_blas (
245
- const float * x,
246
- const float * y,
247
- size_t d, size_t nx, size_t ny,
248
- ResultHandler & res,
249
- const float *y_norms = nullptr)
250
- {
216
+ template <class ResultHandler>
217
+ void exhaustive_L2sqr_blas(
218
+ const float* x,
219
+ const float* y,
220
+ size_t d,
221
+ size_t nx,
222
+ size_t ny,
223
+ ResultHandler& res,
224
+ const float* y_norms = nullptr) {
251
225
  // BLAS does not like empty matrices
252
- if (nx == 0 || ny == 0) return;
226
+ if (nx == 0 || ny == 0)
227
+ return;
253
228
 
254
229
  /* block sizes */
255
230
  const size_t bs_x = distance_compute_blas_query_bs;
256
231
  const size_t bs_y = distance_compute_blas_database_bs;
257
232
  // const size_t bs_x = 16, bs_y = 16;
258
- std::unique_ptr<float []> ip_block(new float[bs_x * bs_y]);
259
- std::unique_ptr<float []> x_norms(new float[nx]);
260
- std::unique_ptr<float []> del2;
233
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
234
+ std::unique_ptr<float[]> x_norms(new float[nx]);
235
+ std::unique_ptr<float[]> del2;
261
236
 
262
- fvec_norms_L2sqr (x_norms.get(), x, d, nx);
237
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
263
238
 
264
239
  if (!y_norms) {
265
- float *y_norms2 = new float[ny];
240
+ float* y_norms2 = new float[ny];
266
241
  del2.reset(y_norms2);
267
- fvec_norms_L2sqr (y_norms2, y, d, ny);
242
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
268
243
  y_norms = y_norms2;
269
244
  }
270
245
 
271
246
  for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
272
247
  size_t i1 = i0 + bs_x;
273
- if(i1 > nx) i1 = nx;
248
+ if (i1 > nx)
249
+ i1 = nx;
274
250
 
275
251
  res.begin_multiple(i0, i1);
276
252
 
277
253
  for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
278
254
  size_t j1 = j0 + bs_y;
279
- if (j1 > ny) j1 = ny;
255
+ if (j1 > ny)
256
+ j1 = ny;
280
257
  /* compute the actual dot products */
281
258
  {
282
259
  float one = 1, zero = 0;
283
260
  FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
284
- sgemm_ ("Transpose", "Not transpose", &nyi, &nxi, &di, &one,
285
- y + j0 * d, &di,
286
- x + i0 * d, &di, &zero,
287
- ip_block.get(), &nyi);
261
+ sgemm_("Transpose",
262
+ "Not transpose",
263
+ &nyi,
264
+ &nxi,
265
+ &di,
266
+ &one,
267
+ y + j0 * d,
268
+ &di,
269
+ x + i0 * d,
270
+ &di,
271
+ &zero,
272
+ ip_block.get(),
273
+ &nyi);
288
274
  }
289
-
275
+ #pragma omp parallel for
290
276
  for (int64_t i = i0; i < i1; i++) {
291
- float *ip_line = ip_block.get() + (i - i0) * (j1 - j0);
277
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
292
278
 
293
279
  for (size_t j = j0; j < j1; j++) {
294
280
  float ip = *ip_line;
@@ -296,7 +282,8 @@ void exhaustive_L2sqr_blas (
296
282
 
297
283
  // negative values can occur for identical vectors
298
284
  // due to roundoff errors
299
- if (dis < 0) dis = 0;
285
+ if (dis < 0)
286
+ dis = 0;
300
287
 
301
288
  *ip_line = dis;
302
289
  ip_line++;
@@ -305,18 +292,12 @@ void exhaustive_L2sqr_blas (
305
292
  res.add_results(j0, j1, ip_block.get());
306
293
  }
307
294
  res.end_multiple();
308
- InterruptCallback::check ();
295
+ InterruptCallback::check();
309
296
  }
310
297
  }
311
298
 
312
-
313
-
314
299
  } // anonymous namespace
315
300
 
316
-
317
-
318
-
319
-
320
301
  /*******************************************************
321
302
  * KNN driver functions
322
303
  *******************************************************/
@@ -326,268 +307,275 @@ int distance_compute_blas_query_bs = 4096;
326
307
  int distance_compute_blas_database_bs = 1024;
327
308
  int distance_compute_min_k_reservoir = 100;
328
309
 
329
- void knn_inner_product (const float * x,
330
- const float * y,
331
- size_t d, size_t nx, size_t ny,
332
- float_minheap_array_t * ha)
333
- {
310
+ void knn_inner_product(
311
+ const float* x,
312
+ const float* y,
313
+ size_t d,
314
+ size_t nx,
315
+ size_t ny,
316
+ float_minheap_array_t* ha) {
334
317
  if (ha->k < distance_compute_min_k_reservoir) {
335
318
  HeapResultHandler<CMin<float, int64_t>> res(
336
- ha->nh, ha->val, ha->ids, ha->k);
319
+ ha->nh, ha->val, ha->ids, ha->k);
337
320
  if (nx < distance_compute_blas_threshold) {
338
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
321
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
339
322
  } else {
340
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
323
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
341
324
  }
342
325
  } else {
343
326
  ReservoirResultHandler<CMin<float, int64_t>> res(
344
- ha->nh, ha->val, ha->ids, ha->k);
327
+ ha->nh, ha->val, ha->ids, ha->k);
345
328
  if (nx < distance_compute_blas_threshold) {
346
- exhaustive_inner_product_seq (x, y, d, nx, ny, res);
329
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
347
330
  } else {
348
- exhaustive_inner_product_blas (x, y, d, nx, ny, res);
331
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
349
332
  }
350
333
  }
351
334
  }
352
335
 
353
-
354
-
355
-
356
- void knn_L2sqr (
357
- const float * x,
358
- const float * y,
359
- size_t d, size_t nx, size_t ny,
360
- float_maxheap_array_t * ha,
361
- const float *y_norm2
362
- ) {
363
-
336
+ void knn_L2sqr(
337
+ const float* x,
338
+ const float* y,
339
+ size_t d,
340
+ size_t nx,
341
+ size_t ny,
342
+ float_maxheap_array_t* ha,
343
+ const float* y_norm2) {
364
344
  if (ha->k < distance_compute_min_k_reservoir) {
365
345
  HeapResultHandler<CMax<float, int64_t>> res(
366
- ha->nh, ha->val, ha->ids, ha->k);
346
+ ha->nh, ha->val, ha->ids, ha->k);
367
347
 
368
348
  if (nx < distance_compute_blas_threshold) {
369
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
349
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
370
350
  } else {
371
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
351
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
372
352
  }
373
353
  } else {
374
354
  ReservoirResultHandler<CMax<float, int64_t>> res(
375
- ha->nh, ha->val, ha->ids, ha->k);
355
+ ha->nh, ha->val, ha->ids, ha->k);
376
356
  if (nx < distance_compute_blas_threshold) {
377
- exhaustive_L2sqr_seq (x, y, d, nx, ny, res);
357
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
378
358
  } else {
379
- exhaustive_L2sqr_blas (x, y, d, nx, ny, res, y_norm2);
359
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
380
360
  }
381
361
  }
382
362
  }
383
363
 
384
-
385
364
  /***************************************************************************
386
365
  * Range search
387
366
  ***************************************************************************/
388
367
 
389
-
390
-
391
-
392
- void range_search_L2sqr (
393
- const float * x,
394
- const float * y,
395
- size_t d, size_t nx, size_t ny,
368
+ void range_search_L2sqr(
369
+ const float* x,
370
+ const float* y,
371
+ size_t d,
372
+ size_t nx,
373
+ size_t ny,
396
374
  float radius,
397
- RangeSearchResult *res)
398
- {
375
+ RangeSearchResult* res) {
399
376
  RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
400
377
  if (nx < distance_compute_blas_threshold) {
401
- exhaustive_L2sqr_seq (x, y, d, nx, ny, resh);
378
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
402
379
  } else {
403
- exhaustive_L2sqr_blas (x, y, d, nx, ny, resh);
380
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
404
381
  }
405
382
  }
406
383
 
407
- void range_search_inner_product (
408
- const float * x,
409
- const float * y,
410
- size_t d, size_t nx, size_t ny,
384
+ void range_search_inner_product(
385
+ const float* x,
386
+ const float* y,
387
+ size_t d,
388
+ size_t nx,
389
+ size_t ny,
411
390
  float radius,
412
- RangeSearchResult *res)
413
- {
414
-
391
+ RangeSearchResult* res) {
415
392
  RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
416
393
  if (nx < distance_compute_blas_threshold) {
417
- exhaustive_inner_product_seq (x, y, d, nx, ny, resh);
394
+ exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
418
395
  } else {
419
- exhaustive_inner_product_blas (x, y, d, nx, ny, resh);
396
+ exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
420
397
  }
421
398
  }
422
399
 
423
-
424
400
  /***************************************************************************
425
401
  * compute a subset of distances
426
402
  ***************************************************************************/
427
403
 
428
404
  /* compute the inner product between x and a subset y of ny vectors,
429
405
  whose indices are given by idy. */
430
- void fvec_inner_products_by_idx (float * __restrict ip,
431
- const float * x,
432
- const float * y,
433
- const int64_t * __restrict ids, /* for y vecs */
434
- size_t d, size_t nx, size_t ny)
435
- {
406
+ void fvec_inner_products_by_idx(
407
+ float* __restrict ip,
408
+ const float* x,
409
+ const float* y,
410
+ const int64_t* __restrict ids, /* for y vecs */
411
+ size_t d,
412
+ size_t nx,
413
+ size_t ny) {
436
414
  #pragma omp parallel for
437
415
  for (int64_t j = 0; j < nx; j++) {
438
- const int64_t * __restrict idsj = ids + j * ny;
439
- const float * xj = x + j * d;
440
- float * __restrict ipj = ip + j * ny;
416
+ const int64_t* __restrict idsj = ids + j * ny;
417
+ const float* xj = x + j * d;
418
+ float* __restrict ipj = ip + j * ny;
441
419
  for (size_t i = 0; i < ny; i++) {
442
420
  if (idsj[i] < 0)
443
421
  continue;
444
- ipj[i] = fvec_inner_product (xj, y + d * idsj[i], d);
422
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
445
423
  }
446
424
  }
447
425
  }
448
426
 
449
-
450
-
451
427
  /* compute the inner product between x and a subset y of ny vectors,
452
428
  whose indices are given by idy. */
453
- void fvec_L2sqr_by_idx (float * __restrict dis,
454
- const float * x,
455
- const float * y,
456
- const int64_t * __restrict ids, /* ids of y vecs */
457
- size_t d, size_t nx, size_t ny)
458
- {
429
+ void fvec_L2sqr_by_idx(
430
+ float* __restrict dis,
431
+ const float* x,
432
+ const float* y,
433
+ const int64_t* __restrict ids, /* ids of y vecs */
434
+ size_t d,
435
+ size_t nx,
436
+ size_t ny) {
459
437
  #pragma omp parallel for
460
438
  for (int64_t j = 0; j < nx; j++) {
461
- const int64_t * __restrict idsj = ids + j * ny;
462
- const float * xj = x + j * d;
463
- float * __restrict disj = dis + j * ny;
439
+ const int64_t* __restrict idsj = ids + j * ny;
440
+ const float* xj = x + j * d;
441
+ float* __restrict disj = dis + j * ny;
464
442
  for (size_t i = 0; i < ny; i++) {
465
443
  if (idsj[i] < 0)
466
444
  continue;
467
- disj[i] = fvec_L2sqr (xj, y + d * idsj[i], d);
445
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
468
446
  }
469
447
  }
470
448
  }
471
449
 
472
- void pairwise_indexed_L2sqr (
473
- size_t d, size_t n,
474
- const float * x, const int64_t *ix,
475
- const float * y, const int64_t *iy,
476
- float *dis)
477
- {
450
+ void pairwise_indexed_L2sqr(
451
+ size_t d,
452
+ size_t n,
453
+ const float* x,
454
+ const int64_t* ix,
455
+ const float* y,
456
+ const int64_t* iy,
457
+ float* dis) {
478
458
  #pragma omp parallel for
479
459
  for (int64_t j = 0; j < n; j++) {
480
460
  if (ix[j] >= 0 && iy[j] >= 0) {
481
- dis[j] = fvec_L2sqr (x + d * ix[j], y + d * iy[j], d);
461
+ dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
482
462
  }
483
463
  }
484
464
  }
485
465
 
486
- void pairwise_indexed_inner_product (
487
- size_t d, size_t n,
488
- const float * x, const int64_t *ix,
489
- const float * y, const int64_t *iy,
490
- float *dis)
491
- {
466
+ void pairwise_indexed_inner_product(
467
+ size_t d,
468
+ size_t n,
469
+ const float* x,
470
+ const int64_t* ix,
471
+ const float* y,
472
+ const int64_t* iy,
473
+ float* dis) {
492
474
  #pragma omp parallel for
493
475
  for (int64_t j = 0; j < n; j++) {
494
476
  if (ix[j] >= 0 && iy[j] >= 0) {
495
- dis[j] = fvec_inner_product (x + d * ix[j], y + d * iy[j], d);
477
+ dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
496
478
  }
497
479
  }
498
480
  }
499
481
 
500
-
501
482
  /* Find the nearest neighbors for nx queries in a set of ny vectors
502
483
  indexed by ids. May be useful for re-ranking a pre-selected vector list */
503
- void knn_inner_products_by_idx (const float * x,
504
- const float * y,
505
- const int64_t * ids,
506
- size_t d, size_t nx, size_t ny,
507
- float_minheap_array_t * res)
508
- {
484
+ void knn_inner_products_by_idx(
485
+ const float* x,
486
+ const float* y,
487
+ const int64_t* ids,
488
+ size_t d,
489
+ size_t nx,
490
+ size_t ny,
491
+ float_minheap_array_t* res) {
509
492
  size_t k = res->k;
510
493
 
511
494
  #pragma omp parallel for
512
495
  for (int64_t i = 0; i < nx; i++) {
513
- const float * x_ = x + i * d;
514
- const int64_t * idsi = ids + i * ny;
496
+ const float* x_ = x + i * d;
497
+ const int64_t* idsi = ids + i * ny;
515
498
  size_t j;
516
- float * __restrict simi = res->get_val(i);
517
- int64_t * __restrict idxi = res->get_ids (i);
518
- minheap_heapify (k, simi, idxi);
499
+ float* __restrict simi = res->get_val(i);
500
+ int64_t* __restrict idxi = res->get_ids(i);
501
+ minheap_heapify(k, simi, idxi);
519
502
 
520
503
  for (j = 0; j < ny; j++) {
521
- if (idsi[j] < 0) break;
522
- float ip = fvec_inner_product (x_, y + d * idsi[j], d);
504
+ if (idsi[j] < 0)
505
+ break;
506
+ float ip = fvec_inner_product(x_, y + d * idsi[j], d);
523
507
 
524
508
  if (ip > simi[0]) {
525
- minheap_replace_top (k, simi, idxi, ip, idsi[j]);
509
+ minheap_replace_top(k, simi, idxi, ip, idsi[j]);
526
510
  }
527
511
  }
528
- minheap_reorder (k, simi, idxi);
512
+ minheap_reorder(k, simi, idxi);
529
513
  }
530
-
531
514
  }
532
515
 
533
- void knn_L2sqr_by_idx (const float * x,
534
- const float * y,
535
- const int64_t * __restrict ids,
536
- size_t d, size_t nx, size_t ny,
537
- float_maxheap_array_t * res)
538
- {
516
+ void knn_L2sqr_by_idx(
517
+ const float* x,
518
+ const float* y,
519
+ const int64_t* __restrict ids,
520
+ size_t d,
521
+ size_t nx,
522
+ size_t ny,
523
+ float_maxheap_array_t* res) {
539
524
  size_t k = res->k;
540
525
 
541
526
  #pragma omp parallel for
542
527
  for (int64_t i = 0; i < nx; i++) {
543
- const float * x_ = x + i * d;
544
- const int64_t * __restrict idsi = ids + i * ny;
545
- float * __restrict simi = res->get_val(i);
546
- int64_t * __restrict idxi = res->get_ids (i);
547
- maxheap_heapify (res->k, simi, idxi);
528
+ const float* x_ = x + i * d;
529
+ const int64_t* __restrict idsi = ids + i * ny;
530
+ float* __restrict simi = res->get_val(i);
531
+ int64_t* __restrict idxi = res->get_ids(i);
532
+ maxheap_heapify(res->k, simi, idxi);
548
533
  for (size_t j = 0; j < ny; j++) {
549
- float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
534
+ float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
550
535
 
551
536
  if (disij < simi[0]) {
552
- maxheap_replace_top (k, simi, idxi, disij, idsi[j]);
537
+ maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
553
538
  }
554
539
  }
555
- maxheap_reorder (res->k, simi, idxi);
540
+ maxheap_reorder(res->k, simi, idxi);
556
541
  }
557
-
558
542
  }
559
543
 
560
-
561
-
562
-
563
-
564
- void pairwise_L2sqr (int64_t d,
565
- int64_t nq, const float *xq,
566
- int64_t nb, const float *xb,
567
- float *dis,
568
- int64_t ldq, int64_t ldb, int64_t ldd)
569
- {
570
- if (nq == 0 || nb == 0) return;
571
- if (ldq == -1) ldq = d;
572
- if (ldb == -1) ldb = d;
573
- if (ldd == -1) ldd = nb;
544
+ void pairwise_L2sqr(
545
+ int64_t d,
546
+ int64_t nq,
547
+ const float* xq,
548
+ int64_t nb,
549
+ const float* xb,
550
+ float* dis,
551
+ int64_t ldq,
552
+ int64_t ldb,
553
+ int64_t ldd) {
554
+ if (nq == 0 || nb == 0)
555
+ return;
556
+ if (ldq == -1)
557
+ ldq = d;
558
+ if (ldb == -1)
559
+ ldb = d;
560
+ if (ldd == -1)
561
+ ldd = nb;
574
562
 
575
563
  // store in beginning of distance matrix to avoid malloc
576
- float *b_norms = dis;
564
+ float* b_norms = dis;
577
565
 
578
566
  #pragma omp parallel for
579
567
  for (int64_t i = 0; i < nb; i++)
580
- b_norms [i] = fvec_norm_L2sqr (xb + i * ldb, d);
568
+ b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
581
569
 
582
570
  #pragma omp parallel for
583
571
  for (int64_t i = 1; i < nq; i++) {
584
- float q_norm = fvec_norm_L2sqr (xq + i * ldq, d);
572
+ float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
585
573
  for (int64_t j = 0; j < nb; j++)
586
- dis[i * ldd + j] = q_norm + b_norms [j];
574
+ dis[i * ldd + j] = q_norm + b_norms[j];
587
575
  }
588
576
 
589
577
  {
590
- float q_norm = fvec_norm_L2sqr (xq, d);
578
+ float q_norm = fvec_norm_L2sqr(xq, d);
591
579
  for (int64_t j = 0; j < nb; j++)
592
580
  dis[j] += q_norm;
593
581
  }
@@ -596,22 +584,28 @@ void pairwise_L2sqr (int64_t d,
596
584
  FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
597
585
  float one = 1.0, minus_2 = -2.0;
598
586
 
599
- sgemm_ ("Transposed", "Not transposed",
600
- &nbi, &nqi, &di,
601
- &minus_2,
602
- xb, &ldbi,
603
- xq, &ldqi,
604
- &one, dis, &lddi);
587
+ sgemm_("Transposed",
588
+ "Not transposed",
589
+ &nbi,
590
+ &nqi,
591
+ &di,
592
+ &minus_2,
593
+ xb,
594
+ &ldbi,
595
+ xq,
596
+ &ldqi,
597
+ &one,
598
+ dis,
599
+ &lddi);
605
600
  }
606
-
607
601
  }
608
602
 
609
- void inner_product_to_L2sqr(float* __restrict dis,
610
- const float* nr1,
611
- const float* nr2,
612
- size_t n1, size_t n2)
613
- {
614
-
603
+ void inner_product_to_L2sqr(
604
+ float* __restrict dis,
605
+ const float* nr1,
606
+ const float* nr2,
607
+ size_t n1,
608
+ size_t n2) {
615
609
  #pragma omp parallel for
616
610
  for (int64_t j = 0; j < n1; j++) {
617
611
  float* disj = dis + j * n2;
@@ -620,5 +614,4 @@ void inner_product_to_L2sqr(float* __restrict dis,
620
614
  }
621
615
  }
622
616
 
623
-
624
617
  } // namespace faiss