ffi-fasttext 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +44 -0
- data/.travis.yml +5 -0
- data/Gemfile +6 -0
- data/LICENSE.txt +21 -0
- data/README.md +59 -0
- data/Rakefile +19 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/ext/ffi/fasttext/Rakefile +71 -0
- data/ffi-fasttext.gemspec +40 -0
- data/lib/ffi/fasttext.rb +108 -0
- data/lib/ffi/fasttext/version.rb +5 -0
- data/vendor/fasttext/LICENSE +30 -0
- data/vendor/fasttext/PATENTS +33 -0
- data/vendor/fasttext/args.cc +250 -0
- data/vendor/fasttext/args.h +71 -0
- data/vendor/fasttext/dictionary.cc +475 -0
- data/vendor/fasttext/dictionary.h +112 -0
- data/vendor/fasttext/fasttext.cc +693 -0
- data/vendor/fasttext/fasttext.h +97 -0
- data/vendor/fasttext/ffi_fasttext.cc +66 -0
- data/vendor/fasttext/main.cc +270 -0
- data/vendor/fasttext/matrix.cc +144 -0
- data/vendor/fasttext/matrix.h +57 -0
- data/vendor/fasttext/model.cc +341 -0
- data/vendor/fasttext/model.h +110 -0
- data/vendor/fasttext/productquantizer.cc +211 -0
- data/vendor/fasttext/productquantizer.h +67 -0
- data/vendor/fasttext/qmatrix.cc +121 -0
- data/vendor/fasttext/qmatrix.h +65 -0
- data/vendor/fasttext/real.h +19 -0
- data/vendor/fasttext/utils.cc +29 -0
- data/vendor/fasttext/utils.h +25 -0
- data/vendor/fasttext/vector.cc +137 -0
- data/vendor/fasttext/vector.h +53 -0
- metadata +151 -0
| @@ -0,0 +1,211 @@ | |
| 1 | 
            +
            /**
         | 
| 2 | 
            +
             * Copyright (c) 2016-present, Facebook, Inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             *
         | 
| 5 | 
            +
             * This source code is licensed under the BSD-style license found in the
         | 
| 6 | 
            +
             * LICENSE file in the root directory of this source tree. An additional grant
         | 
| 7 | 
            +
             * of patent rights can be found in the PATENTS file in the same directory.
         | 
| 8 | 
            +
             */
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            #include "productquantizer.h"
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include <algorithm>
            #include <cstdlib>
            #include <cstring>
            #include <iostream>
            #include <numeric>
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            namespace fasttext {
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            // Squared Euclidean distance between the first d entries of x and y.
            // No square root is taken; the value is only used for comparisons and
            // is returned as-is.
            real distL2(const real* x, const real* y, int32_t d) {
              real acc = 0;
              for (int32_t k = 0; k < d; ++k) {
                const auto delta = x[k] - y[k];
                acc += delta * delta;
              }
              return acc;
            }
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            // Builds a quantizer for dim-dimensional vectors cut into slices of dsub
            // components. When dim is not a multiple of dsub, one extra, shorter
            // sub-quantizer covers the trailing dim % dsub components.
            ProductQuantizer::ProductQuantizer(int32_t dim, int32_t dsub)
                : dim_(dim),
                  nsubq_(dim / dsub),
                  dsub_(dsub),
                  centroids_(dim * ksub_),
                  rng(seed_) {
              const int32_t remainder = dim_ % dsub;
              if (remainder != 0) {
                lastdsub_ = remainder;
                nsubq_++;
              } else {
                lastdsub_ = dsub_;
              }
            }
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            // Returns a pointer to centroid i of sub-quantizer m.
            // Each regular sub-quantizer owns ksub_ * dsub_ consecutive values, so
            // centroid i starts at (m * ksub_ + i) * dsub_. The last sub-quantizer
            // starts at the same block offset but strides by lastdsub_, which may be
            // shorter than dsub_.
            const real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) const {
              if (m == nsubq_ - 1) {
                return &centroids_[m * ksub_ * dsub_ + i * lastdsub_];
              }
              return &centroids_[(m * ksub_ + i) * dsub_];
            }
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            // Mutable counterpart of the const overload: returns a pointer to centroid
            // i of sub-quantizer m. The last sub-quantizer strides by lastdsub_;
            // all others stride by dsub_.
            real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) {
              if (m == nsubq_ - 1) {
                return &centroids_[m * ksub_ * dsub_ + i * lastdsub_];
              }
              return &centroids_[(m * ksub_ + i) * dsub_];
            }
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            // Scans the ksub_ centroids stored back to back from c0 (d values each).
            // Writes the index of the one closest to x into code[0] and returns the
            // squared L2 distance to it. On ties the lowest index wins.
            real ProductQuantizer::assign_centroid(const real * x, const real* c0,
                                                   uint8_t* code, int32_t d) const {
              int32_t best = 0;
              real bestDist = distL2(x, c0, d);
              for (int32_t j = 1; j < ksub_; ++j) {
                const real dist = distL2(x, c0 + j * d, d);
                if (dist < bestDist) {
                  best = j;
                  bestDist = dist;
                }
              }
              code[0] = static_cast<uint8_t>(best);
              return bestDist;
            }
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            void ProductQuantizer::Estep(const real* x, const real* centroids,
         | 
| 60 | 
            +
                                         uint8_t* codes, int32_t d,
         | 
| 61 | 
            +
                                         int32_t n) const {
         | 
| 62 | 
            +
              for (auto i = 0; i < n; i++) {
         | 
| 63 | 
            +
                assign_centroid(x + i * d, centroids, codes + i, d);
         | 
| 64 | 
            +
              }
         | 
| 65 | 
            +
            }
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            void ProductQuantizer::MStep(const real* x0, real* centroids,
         | 
| 68 | 
            +
                                         const uint8_t* codes,
         | 
| 69 | 
            +
                                         int32_t d, int32_t n) {
         | 
| 70 | 
            +
              std::vector<int32_t> nelts(ksub_, 0);
         | 
| 71 | 
            +
              memset(centroids, 0, sizeof(real) * d * ksub_);
         | 
| 72 | 
            +
              const real* x = x0;
         | 
| 73 | 
            +
              for (auto i = 0; i < n; i++) {
         | 
| 74 | 
            +
                auto k = codes[i];
         | 
| 75 | 
            +
                real* c = centroids + k * d;
         | 
| 76 | 
            +
                for (auto j = 0; j < d; j++) {
         | 
| 77 | 
            +
                  c[j] += x[j];
         | 
| 78 | 
            +
                }
         | 
| 79 | 
            +
                nelts[k]++;
         | 
| 80 | 
            +
                x += d;
         | 
| 81 | 
            +
              }
         | 
| 82 | 
            +
             | 
| 83 | 
            +
              real* c = centroids;
         | 
| 84 | 
            +
              for (auto k = 0; k < ksub_; k++) {
         | 
| 85 | 
            +
                real z = (real) nelts[k];
         | 
| 86 | 
            +
                if (z != 0) {
         | 
| 87 | 
            +
                  for (auto j = 0; j < d; j++) {
         | 
| 88 | 
            +
                    c[j] /= z;
         | 
| 89 | 
            +
                  }
         | 
| 90 | 
            +
                }
         | 
| 91 | 
            +
                c += d;
         | 
| 92 | 
            +
              }
         | 
| 93 | 
            +
             | 
| 94 | 
            +
              std::uniform_real_distribution<> runiform(0,1);
         | 
| 95 | 
            +
              for (auto k = 0; k < ksub_; k++) {
         | 
| 96 | 
            +
                if (nelts[k] == 0) {
         | 
| 97 | 
            +
                  int32_t m = 0;
         | 
| 98 | 
            +
                  while (runiform(rng) * (n - ksub_) >= nelts[m] - 1) {
         | 
| 99 | 
            +
                    m = (m + 1) % ksub_;
         | 
| 100 | 
            +
                  }
         | 
| 101 | 
            +
                  memcpy(centroids + k * d, centroids + m * d, sizeof(real) * d);
         | 
| 102 | 
            +
                  for (auto j = 0; j < d; j++) {
         | 
| 103 | 
            +
                    int32_t sign = (j % 2) * 2 - 1;
         | 
| 104 | 
            +
                    centroids[k * d + j] += sign * eps_;
         | 
| 105 | 
            +
                    centroids[m * d + j] -= sign * eps_;
         | 
| 106 | 
            +
                  }
         | 
| 107 | 
            +
                  nelts[k] = nelts[m] / 2;
         | 
| 108 | 
            +
                  nelts[m] -= nelts[k];
         | 
| 109 | 
            +
                }
         | 
| 110 | 
            +
              }
         | 
| 111 | 
            +
            }
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            void ProductQuantizer::kmeans(const real *x, real* c, int32_t n, int32_t d) {
         | 
| 114 | 
            +
              std::vector<int32_t> perm(n,0);
         | 
| 115 | 
            +
              std::iota(perm.begin(), perm.end(), 0);
         | 
| 116 | 
            +
              std::shuffle(perm.begin(), perm.end(), rng);
         | 
| 117 | 
            +
              for (auto i = 0; i < ksub_; i++) {
         | 
| 118 | 
            +
                memcpy (&c[i * d], x + perm[i] * d, d * sizeof(real));
         | 
| 119 | 
            +
              }
         | 
| 120 | 
            +
              uint8_t* codes = new uint8_t[n];
         | 
| 121 | 
            +
              for (auto i = 0; i < niter_; i++) {
         | 
| 122 | 
            +
                Estep(x, c, codes, d, n);
         | 
| 123 | 
            +
                MStep(x, c, codes, d, n);
         | 
| 124 | 
            +
              }
         | 
| 125 | 
            +
              delete [] codes;
         | 
| 126 | 
            +
            }
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            void ProductQuantizer::train(int32_t n, const real * x) {
         | 
| 129 | 
            +
              if (n < ksub_) {
         | 
| 130 | 
            +
                std::cerr<<"Matrix too small for quantization, must have > 256 rows"<<std::endl;
         | 
| 131 | 
            +
                exit(1);
         | 
| 132 | 
            +
              }
         | 
| 133 | 
            +
              std::vector<int32_t> perm(n, 0);
         | 
| 134 | 
            +
              std::iota(perm.begin(), perm.end(), 0);
         | 
| 135 | 
            +
              auto d = dsub_;
         | 
| 136 | 
            +
              auto np = std::min(n, max_points_);
         | 
| 137 | 
            +
              real* xslice = new real[np * dsub_];
         | 
| 138 | 
            +
              for (auto m = 0; m < nsubq_; m++) {
         | 
| 139 | 
            +
                if (m == nsubq_-1) {d = lastdsub_;}
         | 
| 140 | 
            +
                if (np != n) {std::shuffle(perm.begin(), perm.end(), rng);}
         | 
| 141 | 
            +
                for (auto j = 0; j < np; j++) {
         | 
| 142 | 
            +
                  memcpy (xslice + j * d, x + perm[j] * dim_ + m * dsub_, d * sizeof(real));
         | 
| 143 | 
            +
                }
         | 
| 144 | 
            +
                kmeans(xslice, get_centroids(m, 0), np, d);
         | 
| 145 | 
            +
              }
         | 
| 146 | 
            +
              delete [] xslice;
         | 
| 147 | 
            +
            }
         | 
| 148 | 
            +
             | 
| 149 | 
            +
            // Returns alpha times the dot product of x with the reconstruction of row t.
            // The reconstruction is the concatenation of the centroids selected by
            // that row's codes; codes holds nsubq_ bytes per row.
            real ProductQuantizer::mulcode(const Vector& x, const uint8_t* codes,
                                           int32_t t, real alpha) const {
              real res = 0.0;
              auto d = dsub_;
              // Widen before multiplying: nsubq_ * t is a signed 32-bit product that
              // can overflow for matrices with many rows.
              const uint8_t* code = codes + static_cast<int64_t>(nsubq_) * t;
              for (auto m = 0; m < nsubq_; m++) {
                const real* c = get_centroids(m, code[m]);
                if (m == nsubq_ - 1) { d = lastdsub_; }
                for (auto n = 0; n < d; n++) {
                  res += x[m * dsub_ + n] * c[n];
                }
              }
              return res * alpha;
            }
         | 
| 163 | 
            +
             | 
| 164 | 
            +
            // Adds alpha times the reconstruction of row t to x. The reconstruction is
            // the concatenation of the centroids selected by that row's nsubq_ codes.
            void ProductQuantizer::addcode(Vector& x, const uint8_t* codes,
                                           int32_t t, real alpha) const {
              auto d = dsub_;
              // Widen before multiplying: nsubq_ * t is a signed 32-bit product that
              // can overflow for matrices with many rows.
              const uint8_t* code = codes + static_cast<int64_t>(nsubq_) * t;
              for (auto m = 0; m < nsubq_; m++) {
                const real* c = get_centroids(m, code[m]);
                if (m == nsubq_ - 1) { d = lastdsub_; }
                for (auto n = 0; n < d; n++) {
                  x[m * dsub_ + n] += alpha * c[n];
                }
              }
            }
         | 
| 176 | 
            +
             | 
| 177 | 
            +
            void ProductQuantizer::compute_code(const real* x, uint8_t* code) const {
         | 
| 178 | 
            +
              auto d = dsub_;
         | 
| 179 | 
            +
              for (auto m = 0; m < nsubq_; m++) {
         | 
| 180 | 
            +
                if (m == nsubq_ - 1) {d = lastdsub_;}
         | 
| 181 | 
            +
                assign_centroid(x + m * dsub_, get_centroids(m, 0), code + m, d);
         | 
| 182 | 
            +
              }
         | 
| 183 | 
            +
            }
         | 
| 184 | 
            +
             | 
| 185 | 
            +
            void ProductQuantizer::compute_codes(const real* x, uint8_t* codes,
         | 
| 186 | 
            +
                                                 int32_t n) const {
         | 
| 187 | 
            +
              for (auto i = 0; i < n; i++) {
         | 
| 188 | 
            +
                compute_code(x + i * dim_, codes + i * nsubq_);
         | 
| 189 | 
            +
              }
         | 
| 190 | 
            +
            }
         | 
| 191 | 
            +
             | 
| 192 | 
            +
            // Serialises the quantizer as raw bytes. First come the four int32_t shape
            // fields (dim_, nsubq_, dsub_, lastdsub_), then all centroid values.
            void ProductQuantizer::save(std::ostream& out) {
              const int32_t* header[] = {&dim_, &nsubq_, &dsub_, &lastdsub_};
              for (const int32_t* field : header) {
                out.write(reinterpret_cast<const char*>(field), sizeof(*field));
              }
              out.write(reinterpret_cast<const char*>(centroids_.data()),
                        centroids_.size() * sizeof(real));
            }
         | 
| 199 | 
            +
             | 
| 200 | 
            +
            void ProductQuantizer::load(std::istream& in) {
         | 
| 201 | 
            +
              in.read((char*) &dim_, sizeof(dim_));
         | 
| 202 | 
            +
              in.read((char*) &nsubq_, sizeof(nsubq_));
         | 
| 203 | 
            +
              in.read((char*) &dsub_, sizeof(dsub_));
         | 
| 204 | 
            +
              in.read((char*) &lastdsub_, sizeof(lastdsub_));
         | 
| 205 | 
            +
              centroids_.resize(dim_ * ksub_);
         | 
| 206 | 
            +
              for (auto i=0; i < centroids_.size(); i++) {
         | 
| 207 | 
            +
                in.read((char*) ¢roids_[i], sizeof(real));
         | 
| 208 | 
            +
              }
         | 
| 209 | 
            +
            }
         | 
| 210 | 
            +
             | 
| 211 | 
            +
            }
         | 
| @@ -0,0 +1,67 @@ | |
| 1 | 
            +
            /**
         | 
| 2 | 
            +
             * Copyright (c) 2016-present, Facebook, Inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             *
         | 
| 5 | 
            +
             * This source code is licensed under the BSD-style license found in the
         | 
| 6 | 
            +
             * LICENSE file in the root directory of this source tree. An additional grant
         | 
| 7 | 
            +
             * of patent rights can be found in the PATENTS file in the same directory.
         | 
| 8 | 
            +
             */
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            #ifndef FASTTEXT_PRODUCT_QUANTIZER_H
         | 
| 11 | 
            +
            #define FASTTEXT_PRODUCT_QUANTIZER_H
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            #include <cstring>
         | 
| 14 | 
            +
            #include <istream>
         | 
| 15 | 
            +
            #include <ostream>
         | 
| 16 | 
            +
            #include <vector>
         | 
| 17 | 
            +
            #include <random>
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            #include "real.h"
         | 
| 20 | 
            +
            #include "vector.h"
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            namespace fasttext {
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            class ProductQuantizer {
         | 
| 25 | 
            +
              private:
         | 
| 26 | 
            +
                const int32_t nbits_ = 8;
         | 
| 27 | 
            +
                const int32_t ksub_ = 1 << nbits_;
         | 
| 28 | 
            +
                const int32_t max_points_per_cluster_ = 256;
         | 
| 29 | 
            +
                const int32_t max_points_ = max_points_per_cluster_ * ksub_;
         | 
| 30 | 
            +
                const int32_t seed_ = 1234;
         | 
| 31 | 
            +
                const int32_t niter_ = 25;
         | 
| 32 | 
            +
                const real eps_ = 1e-7;
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                int32_t dim_;
         | 
| 35 | 
            +
                int32_t nsubq_;
         | 
| 36 | 
            +
                int32_t dsub_;
         | 
| 37 | 
            +
                int32_t lastdsub_;
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                std::vector<real> centroids_;
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                std::minstd_rand rng;
         | 
| 42 | 
            +
             | 
| 43 | 
            +
              public:
         | 
| 44 | 
            +
                ProductQuantizer() {}
         | 
| 45 | 
            +
                ProductQuantizer(int32_t, int32_t);
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                real* get_centroids (int32_t, uint8_t);
         | 
| 48 | 
            +
                const real* get_centroids(int32_t, uint8_t) const;
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                real assign_centroid(const real*, const real*, uint8_t*, int32_t) const;
         | 
| 51 | 
            +
                void Estep(const real*, const real*, uint8_t*, int32_t, int32_t) const;
         | 
| 52 | 
            +
                void MStep(const real*, real*, const uint8_t*, int32_t, int32_t);
         | 
| 53 | 
            +
                void kmeans(const real*, real*, int32_t, int32_t);
         | 
| 54 | 
            +
                void train(int, const real*);
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                real mulcode(const Vector&, const uint8_t*, int32_t, real) const;
         | 
| 57 | 
            +
                void addcode(Vector&, const uint8_t*, int32_t, real) const;
         | 
| 58 | 
            +
                void compute_code(const real*, uint8_t*)  const;
         | 
| 59 | 
            +
                void compute_codes(const real*, uint8_t*, int32_t)  const;
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                void save(std::ostream&);
         | 
| 62 | 
            +
                void load(std::istream&);
         | 
| 63 | 
            +
            };
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            }
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            #endif
         | 
| @@ -0,0 +1,121 @@ | |
| 1 | 
            +
            /**
         | 
| 2 | 
            +
             * Copyright (c) 2016-present, Facebook, Inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             *
         | 
| 5 | 
            +
             * This source code is licensed under the BSD-style license found in the
         | 
| 6 | 
            +
             * LICENSE file in the root directory of this source tree. An additional grant
         | 
| 7 | 
            +
             * of patent rights can be found in the PATENTS file in the same directory.
         | 
| 8 | 
            +
             */
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            #include "qmatrix.h"
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include <assert.h>
         | 
| 13 | 
            +
            #include <iostream>
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            namespace fasttext {
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            // Creates an empty matrix (e.g. to be filled by load()). The owning code
            // pointers start as nullptr instead of indeterminate values, so they are
            // always safe to inspect or delete.
            QMatrix::QMatrix()
                : codes_(nullptr), norm_codes_(nullptr), qnorm_(false),
                  m_(0), n_(0), codesize_(0) {}
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            // Quantizes mat. Each row is split into ceil(n_ / dsub) sub-vectors,
            // giving codesize_ = m_ * ceil(n_ / dsub) one-byte codes. When qnorm is
            // true, row norms are additionally quantized with a 1-dimensional
            // quantizer (npq_).
            QMatrix::QMatrix(const Matrix& mat, int32_t dsub, bool qnorm)
                  : codes_(nullptr), norm_codes_(nullptr), qnorm_(qnorm),
                    m_(mat.m_), n_(mat.n_),
                    codesize_(m_ * ((n_ + dsub - 1) / dsub)) {
              if (codesize_ > 0) {
                codes_ = new uint8_t[codesize_];
              }
              pq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer(n_, dsub));
              if (qnorm_) {
                norm_codes_ = new uint8_t[m_];
                npq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer(1, 1));
              }
              quantize(mat);
            }
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            // Releases the raw code buffers. norm_codes_ is deleted only when norms are
            // quantized (qnorm_), and codes_ only when codesize_ > 0.
            QMatrix::~QMatrix() {
              if (qnorm_) {
                delete[] norm_codes_;
              }
              if (codesize_ > 0) {
                delete[] codes_;
              }
            }
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            void QMatrix::quantizeNorm(const Vector& norms) {
         | 
| 42 | 
            +
              assert(qnorm_);
         | 
| 43 | 
            +
              assert(norms.m_ == m_);
         | 
| 44 | 
            +
              auto dataptr = norms.data_;
         | 
| 45 | 
            +
              npq_->train(m_, dataptr);
         | 
| 46 | 
            +
              npq_->compute_codes(dataptr, norm_codes_, m_);
         | 
| 47 | 
            +
            }
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            void QMatrix::quantize(const Matrix& matrix) {
         | 
| 50 | 
            +
              assert(n_ == matrix.n_);
         | 
| 51 | 
            +
              assert(m_ == matrix.m_);
         | 
| 52 | 
            +
              Matrix temp(matrix);
         | 
| 53 | 
            +
              if (qnorm_) {
         | 
| 54 | 
            +
                Vector norms(temp.m_);
         | 
| 55 | 
            +
                temp.l2NormRow(norms);
         | 
| 56 | 
            +
                temp.divideRow(norms);
         | 
| 57 | 
            +
                quantizeNorm(norms);
         | 
| 58 | 
            +
              }
         | 
| 59 | 
            +
              auto dataptr = temp.data_;
         | 
| 60 | 
            +
              pq_->train(m_, dataptr);
         | 
| 61 | 
            +
              pq_->compute_codes(dataptr, codes_, m_);
         | 
| 62 | 
            +
            }
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            // Adds the decoded row t to x. The row is scaled by its quantized norm when
            // qnorm_ is set, and by 1 otherwise.
            void QMatrix::addToVector(Vector& x, int32_t t) const {
              const real scale =
                  qnorm_ ? npq_->get_centroids(0, norm_codes_[t])[0] : real(1);
              pq_->addcode(x, codes_, t, scale);
            }
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            // Dot product of vec with the decoded row i. The result is scaled by the
            // row's quantized norm when qnorm_ is set.
            real QMatrix::dotRow(const Vector& vec, int64_t i) const {
              assert(i >= 0);
              assert(i < m_);
              assert(vec.size() == n_);
              const real scale =
                  qnorm_ ? npq_->get_centroids(0, norm_codes_[i])[0] : real(1);
              return pq_->mulcode(vec, codes_, i, scale);
            }
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            int64_t QMatrix::getM() const {
         | 
| 84 | 
            +
              return m_;
         | 
| 85 | 
            +
            }
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            int64_t QMatrix::getN() const {
         | 
| 88 | 
            +
              return n_;
         | 
| 89 | 
            +
            }
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            void QMatrix::save(std::ostream& out) {
         | 
| 92 | 
            +
                out.write((char*) &qnorm_, sizeof(qnorm_));
         | 
| 93 | 
            +
                out.write((char*) &m_, sizeof(m_));
         | 
| 94 | 
            +
                out.write((char*) &n_, sizeof(n_));
         | 
| 95 | 
            +
                out.write((char*) &codesize_, sizeof(codesize_));
         | 
| 96 | 
            +
                out.write((char*) codes_, codesize_ * sizeof(uint8_t));
         | 
| 97 | 
            +
                pq_->save(out);
         | 
| 98 | 
            +
                if (qnorm_) {
         | 
| 99 | 
            +
                  out.write((char*) norm_codes_, m_ * sizeof(uint8_t));
         | 
| 100 | 
            +
                  npq_->save(out);
         | 
| 101 | 
            +
                }
         | 
| 102 | 
            +
            }
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            void QMatrix::load(std::istream& in) {
         | 
| 105 | 
            +
                in.read((char*) &qnorm_, sizeof(qnorm_));
         | 
| 106 | 
            +
                in.read((char*) &m_, sizeof(m_));
         | 
| 107 | 
            +
                in.read((char*) &n_, sizeof(n_));
         | 
| 108 | 
            +
                in.read((char*) &codesize_, sizeof(codesize_));
         | 
| 109 | 
            +
                codes_ = new uint8_t[codesize_];
         | 
| 110 | 
            +
                in.read((char*) codes_, codesize_ * sizeof(uint8_t));
         | 
| 111 | 
            +
                pq_ = std::unique_ptr<ProductQuantizer>( new ProductQuantizer());
         | 
| 112 | 
            +
                pq_->load(in);
         | 
| 113 | 
            +
                if (qnorm_) {
         | 
| 114 | 
            +
                  norm_codes_ = new uint8_t[m_];
         | 
| 115 | 
            +
                  in.read((char*) norm_codes_, m_ * sizeof(uint8_t));
         | 
| 116 | 
            +
                  npq_ = std::unique_ptr<ProductQuantizer>( new ProductQuantizer());
         | 
| 117 | 
            +
                  npq_->load(in);
         | 
| 118 | 
            +
                }
         | 
| 119 | 
            +
            }
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            }
         | 
| @@ -0,0 +1,65 @@ | |
| 1 | 
            +
            /**
         | 
| 2 | 
            +
             * Copyright (c) 2016-present, Facebook, Inc.
         | 
| 3 | 
            +
             * All rights reserved.
         | 
| 4 | 
            +
             *
         | 
| 5 | 
            +
             * This source code is licensed under the BSD-style license found in the
         | 
| 6 | 
            +
             * LICENSE file in the root directory of this source tree. An additional grant
         | 
| 7 | 
            +
             * of patent rights can be found in the PATENTS file in the same directory.
         | 
| 8 | 
            +
             */
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            #ifndef FASTTEXT_QMATRIX_H
         | 
| 11 | 
            +
            #define FASTTEXT_QMATRIX_H
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            #include <cstdint>
         | 
| 14 | 
            +
            #include <istream>
         | 
| 15 | 
            +
            #include <ostream>
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            #include <vector>
         | 
| 18 | 
            +
            #include <memory>
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            #include "real.h"
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            #include "matrix.h"
         | 
| 23 | 
            +
            #include "vector.h"
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            #include "productquantizer.h"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            namespace fasttext {
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            class QMatrix {
         | 
| 30 | 
            +
              private:
         | 
| 31 | 
            +
                std::unique_ptr<ProductQuantizer> pq_;
         | 
| 32 | 
            +
                std::unique_ptr<ProductQuantizer> npq_;
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                uint8_t* codes_;
         | 
| 35 | 
            +
                uint8_t* norm_codes_;
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                bool qnorm_;
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                int64_t m_;
         | 
| 40 | 
            +
                int64_t n_;
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                int32_t codesize_;
         | 
| 43 | 
            +
             | 
| 44 | 
            +
              public:
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                QMatrix();
         | 
| 47 | 
            +
                QMatrix(const Matrix&, int32_t, bool);
         | 
| 48 | 
            +
                ~QMatrix();
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                int64_t getM() const;
         | 
| 51 | 
            +
                int64_t getN() const;
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                void quantizeNorm(const Vector&);
         | 
| 54 | 
            +
                void quantize(const Matrix&);
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                void addToVector(Vector& x, int32_t t) const;
         | 
| 57 | 
            +
                real dotRow(const Vector&, int64_t) const;
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                void save(std::ostream&);
         | 
| 60 | 
            +
                void load(std::istream&);
         | 
| 61 | 
            +
            };
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            }
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            #endif
         |