cld3 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. checksums.yaml +7 -0
  2. data/Gemfile +18 -0
  3. data/LICENSE +204 -0
  4. data/LICENSE_CLD3 +203 -0
  5. data/README.md +22 -0
  6. data/cld3.gemspec +35 -0
  7. data/ext/cld3/base.cc +36 -0
  8. data/ext/cld3/base.h +106 -0
  9. data/ext/cld3/casts.h +98 -0
  10. data/ext/cld3/embedding_feature_extractor.cc +51 -0
  11. data/ext/cld3/embedding_feature_extractor.h +182 -0
  12. data/ext/cld3/embedding_network.cc +196 -0
  13. data/ext/cld3/embedding_network.h +186 -0
  14. data/ext/cld3/embedding_network_params.h +285 -0
  15. data/ext/cld3/extconf.rb +49 -0
  16. data/ext/cld3/feature_extractor.cc +137 -0
  17. data/ext/cld3/feature_extractor.h +633 -0
  18. data/ext/cld3/feature_extractor.proto +50 -0
  19. data/ext/cld3/feature_types.cc +72 -0
  20. data/ext/cld3/feature_types.h +158 -0
  21. data/ext/cld3/fixunicodevalue.cc +55 -0
  22. data/ext/cld3/fixunicodevalue.h +69 -0
  23. data/ext/cld3/float16.h +58 -0
  24. data/ext/cld3/fml_parser.cc +308 -0
  25. data/ext/cld3/fml_parser.h +123 -0
  26. data/ext/cld3/generated_entities.cc +296 -0
  27. data/ext/cld3/generated_ulscript.cc +678 -0
  28. data/ext/cld3/generated_ulscript.h +142 -0
  29. data/ext/cld3/getonescriptspan.cc +1109 -0
  30. data/ext/cld3/getonescriptspan.h +124 -0
  31. data/ext/cld3/integral_types.h +37 -0
  32. data/ext/cld3/lang_id_nn_params.cc +57449 -0
  33. data/ext/cld3/lang_id_nn_params.h +178 -0
  34. data/ext/cld3/language_identifier_features.cc +165 -0
  35. data/ext/cld3/language_identifier_features.h +116 -0
  36. data/ext/cld3/nnet_language_identifier.cc +380 -0
  37. data/ext/cld3/nnet_language_identifier.h +175 -0
  38. data/ext/cld3/nnet_language_identifier_c.cc +72 -0
  39. data/ext/cld3/offsetmap.cc +478 -0
  40. data/ext/cld3/offsetmap.h +168 -0
  41. data/ext/cld3/port.h +143 -0
  42. data/ext/cld3/registry.cc +28 -0
  43. data/ext/cld3/registry.h +242 -0
  44. data/ext/cld3/relevant_script_feature.cc +89 -0
  45. data/ext/cld3/relevant_script_feature.h +49 -0
  46. data/ext/cld3/script_detector.h +156 -0
  47. data/ext/cld3/sentence.proto +77 -0
  48. data/ext/cld3/sentence_features.cc +29 -0
  49. data/ext/cld3/sentence_features.h +35 -0
  50. data/ext/cld3/simple_adder.h +72 -0
  51. data/ext/cld3/stringpiece.h +81 -0
  52. data/ext/cld3/task_context.cc +161 -0
  53. data/ext/cld3/task_context.h +81 -0
  54. data/ext/cld3/task_context_params.cc +74 -0
  55. data/ext/cld3/task_context_params.h +54 -0
  56. data/ext/cld3/task_spec.proto +98 -0
  57. data/ext/cld3/text_processing.cc +245 -0
  58. data/ext/cld3/text_processing.h +30 -0
  59. data/ext/cld3/unicodetext.cc +96 -0
  60. data/ext/cld3/unicodetext.h +144 -0
  61. data/ext/cld3/utf8acceptinterchange.h +486 -0
  62. data/ext/cld3/utf8prop_lettermarkscriptnum.h +1631 -0
  63. data/ext/cld3/utf8repl_lettermarklower.h +758 -0
  64. data/ext/cld3/utf8scannot_lettermarkspecial.h +1455 -0
  65. data/ext/cld3/utf8statetable.cc +1344 -0
  66. data/ext/cld3/utf8statetable.h +285 -0
  67. data/ext/cld3/utils.cc +241 -0
  68. data/ext/cld3/utils.h +144 -0
  69. data/ext/cld3/workspace.cc +64 -0
  70. data/ext/cld3/workspace.h +177 -0
  71. data/lib/cld3.rb +99 -0
  72. metadata +158 -0
data/ext/cld3/embedding_network.cc
@@ -0,0 +1,196 @@
+ /* Copyright 2016 Google Inc. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "embedding_network.h"
+
+ #include "base.h"
+ #include "embedding_network_params.h"
+ #include "float16.h"
+ #include "simple_adder.h"
+
+ namespace chrome_lang_id {
+ namespace {
+
+ using VectorWrapper = EmbeddingNetwork::VectorWrapper;
+
+ void CheckNoQuantization(const EmbeddingNetworkParams::Matrix matrix) {
+   // Quantization not allowed here.
+   CLD3_DCHECK(static_cast<int>(QuantizationType::NONE) ==
+               static_cast<int>(matrix.quant_type));
+ }
+
+ // Fills a Matrix object with the parameters in the given MatrixParams. This
+ // function is used to initialize weight matrices that are *not* embedding
+ // matrices.
+ void FillMatrixParams(const EmbeddingNetworkParams::Matrix source_matrix,
+                       EmbeddingNetwork::Matrix *mat) {
+   mat->resize(source_matrix.rows);
+   CheckNoQuantization(source_matrix);
+   const float *weights =
+       reinterpret_cast<const float *>(source_matrix.elements);
+   for (int r = 0; r < source_matrix.rows; ++r) {
+     (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols);
+     weights += source_matrix.cols;
+   }
+ }
+
+ // Computes y = weights * Relu(x) + b where Relu is optionally applied.
+ template <typename ScaleAdderClass>
+ void SparseReluProductPlusBias(bool apply_relu,
+                                const EmbeddingNetwork::Matrix &weights,
+                                const EmbeddingNetwork::VectorWrapper &b,
+                                const EmbeddingNetwork::Vector &x,
+                                EmbeddingNetwork::Vector *y) {
+   y->assign(b.data(), b.data() + b.size());
+   ScaleAdderClass adder(y->data(), y->size());
+
+   const int x_size = x.size();
+   for (int i = 0; i < x_size; ++i) {
+     const float &scale = x[i];
+     if (apply_relu) {
+       if (scale > 0) {
+         adder.LazyScaleAdd(weights[i].data(), scale);
+       }
+     } else {
+       adder.LazyScaleAdd(weights[i].data(), scale);
+     }
+   }
+   adder.Finalize();
+ }
+ }  // namespace
+
+ void EmbeddingNetwork::ConcatEmbeddings(
+     const std::vector<FeatureVector> &feature_vectors, Vector *concat) const {
+   concat->resize(model_->concat_layer_size());
+
+   // "es_index" stands for "embedding space index".
+   for (size_t es_index = 0; es_index < feature_vectors.size(); ++es_index) {
+     const int concat_offset = model_->concat_offset(es_index);
+     const int embedding_dim = model_->embedding_dim(es_index);
+
+     const EmbeddingMatrix &embedding_matrix = embedding_matrices_[es_index];
+     CLD3_DCHECK(embedding_matrix.dim() == embedding_dim);
+
+     const bool is_quantized =
+         embedding_matrix.quant_type() != QuantizationType::NONE;
+
+     const FeatureVector &feature_vector = feature_vectors[es_index];
+     const int num_features = feature_vector.size();
+     for (int fi = 0; fi < num_features; ++fi) {
+       const FeatureType *feature_type = feature_vector.type(fi);
+       int feature_offset = concat_offset + feature_type->base() * embedding_dim;
+       CLD3_DCHECK(feature_offset + embedding_dim <=
+                   static_cast<int>(concat->size()));
+
+       // Weighted embeddings will be added starting from this address.
+       float *concat_ptr = concat->data() + feature_offset;
+
+       // Pointer to float / uint8 weights for relevant embedding.
+       const void *embedding_data;
+
+       // Multiplier for each embedding weight.
+       float multiplier;
+       const FeatureValue feature_value = feature_vector.value(fi);
+       if (feature_type->is_continuous()) {
+         // Continuous features (encoded as FloatFeatureValue).
+         FloatFeatureValue float_feature_value(feature_value);
+         const int id = float_feature_value.value.id;
+         embedding_matrix.get_embedding(id, &embedding_data, &multiplier);
+         multiplier *= float_feature_value.value.weight;
+       } else {
+         // Discrete features: every present feature has implicit value 1.0.
+         embedding_matrix.get_embedding(feature_value, &embedding_data,
+                                        &multiplier);
+       }
+
+       if (is_quantized) {
+         const uint8 *quant_weights =
+             reinterpret_cast<const uint8 *>(embedding_data);
+         for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
+           // 128 is bias for UINT8 quantization, only one we currently support.
+           *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
+         }
+       } else {
+         const float *weights = reinterpret_cast<const float *>(embedding_data);
+         for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
+           *concat_ptr += *weights * multiplier;
+         }
+       }
+     }
+   }
+ }
+
+ template <typename ScaleAdderClass>
+ void EmbeddingNetwork::FinishComputeFinalScores(const Vector &concat,
+                                                 Vector *scores) const {
+   Vector h0(hidden_bias_[0].size());
+   SparseReluProductPlusBias<ScaleAdderClass>(false, hidden_weights_[0],
+                                              hidden_bias_[0], concat, &h0);
+
+   CLD3_DCHECK((hidden_weights_.size() == 1) || (hidden_weights_.size() == 2));
+   if (hidden_weights_.size() == 1) {  // 1 hidden layer
+     SparseReluProductPlusBias<ScaleAdderClass>(true, softmax_weights_,
+                                                softmax_bias_, h0, scores);
+   } else if (hidden_weights_.size() == 2) {  // 2 hidden layers
+     Vector h1(hidden_bias_[1].size());
+     SparseReluProductPlusBias<ScaleAdderClass>(true, hidden_weights_[1],
+                                                hidden_bias_[1], h0, &h1);
+     SparseReluProductPlusBias<ScaleAdderClass>(true, softmax_weights_,
+                                                softmax_bias_, h1, scores);
+   }
+ }
+
+ void EmbeddingNetwork::ComputeFinalScores(
+     const std::vector<FeatureVector> &features, Vector *scores) const {
+   Vector concat;
+   ConcatEmbeddings(features, &concat);
+
+   scores->resize(softmax_bias_.size());
+   FinishComputeFinalScores<SimpleAdder>(concat, scores);
+ }
+
+ EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
+     : model_(model) {
+   int offset_sum = 0;
+   for (int i = 0; i < model_->embedding_dim_size(); ++i) {
+     CLD3_DCHECK(offset_sum == model_->concat_offset(i));
+     offset_sum += model_->embedding_dim(i) * model_->embedding_num_features(i);
+     embedding_matrices_.emplace_back(model_->GetEmbeddingMatrix(i));
+   }
+
+   CLD3_DCHECK(model_->hidden_size() == model_->hidden_bias_size());
+   hidden_weights_.resize(model_->hidden_size());
+   hidden_bias_.resize(model_->hidden_size());
+   for (int i = 0; i < model_->hidden_size(); ++i) {
+     FillMatrixParams(model_->GetHiddenLayerMatrix(i), &hidden_weights_[i]);
+     EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
+     CLD3_DCHECK(1 == bias.cols);
+     CheckNoQuantization(bias);
+     hidden_bias_[i] = VectorWrapper(
+         reinterpret_cast<const float *>(bias.elements), bias.rows);
+   }
+
+   CLD3_DCHECK(model_->HasSoftmax());
+   FillMatrixParams(model_->GetSoftmaxMatrix(), &softmax_weights_);
+
+   EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
+   CLD3_DCHECK(1 == softmax_bias.cols);
+   CheckNoQuantization(softmax_bias);
+   softmax_bias_ =
+       VectorWrapper(reinterpret_cast<const float *>(softmax_bias.elements),
+                     softmax_bias.rows);
+ }
+
+ }  // namespace chrome_lang_id
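
The templated helper above accumulates y = weights * Relu(x) + b through the ScaleAdderClass (SimpleAdder in ComputeFinalScores), one LazyScaleAdd per input unit. A minimal stand-alone sketch of the same arithmetic, using plain std::vector matrices instead of the gem's VectorWrapper/SimpleAdder types; the name ReluProductPlusBias is ours and not part of cld3:

#include <cstddef>
#include <vector>

// Sketch of what SparseReluProductPlusBias computes: start from the bias,
// then for every input unit i add x[i] * weights[i] into the output,
// skipping non-positive activations when the Relu is applied.
std::vector<float> ReluProductPlusBias(bool apply_relu,
                                       const std::vector<std::vector<float>> &weights,
                                       const std::vector<float> &b,
                                       const std::vector<float> &x) {
  std::vector<float> y = b;  // Mirrors y->assign(b.data(), b.data() + b.size()).
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float scale = x[i];
    if (apply_relu && scale <= 0.0f) continue;  // Relu gates this row out.
    for (std::size_t j = 0; j < y.size(); ++j) {
      y[j] += weights[i][j] * scale;  // The update LazyScaleAdd accumulates.
    }
  }
  return y;
}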
data/ext/cld3/embedding_network.h
@@ -0,0 +1,186 @@
+ /* Copyright 2016 Google Inc. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef EMBEDDING_NETWORK_H_
+ #define EMBEDDING_NETWORK_H_
+
+ #include <vector>
+
+ #include "embedding_network_params.h"
+ #include "feature_extractor.h"
+ #include "float16.h"
+
+ namespace chrome_lang_id {
+
+ // Classifier using a hand-coded feed-forward neural network.
+ //
+ // No gradient computation, just inference.
+ //
+ // Based on the more general nlp_saft::EmbeddingNetwork.
+ //
+ // Classification works as follows:
+ //
+ // Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
+ //
+ // In words: given some discrete features, this class extracts the embeddings
+ // for these features, concatenates them, passes them through one or two hidden
+ // layers (each layer uses Relu) and next through a softmax layer that computes
+ // an unnormalized score for each possible class. Note: there is always a
+ // softmax layer.
+ //
+ // NOTE(salcianu): current code can easily be changed to allow more than two
+ // hidden layers. Feel free to do so if you have a genuine need for that.
+ class EmbeddingNetwork {
+  public:
+   // Class used to represent an embedding matrix. Each row is the embedding of
+   // a vocabulary element. Number of columns = number of embedding dimensions.
+   class EmbeddingMatrix {
+    public:
+     explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)
+         : rows_(source_matrix.rows),
+           cols_(source_matrix.cols),
+           quant_type_(source_matrix.quant_type),
+           data_(source_matrix.elements),
+           row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
+           quant_scales_(source_matrix.quant_scales) {}
+
+     // Returns vocabulary size; one embedding for each vocabulary element.
+     int size() const { return rows_; }
+
+     // Returns number of weights in embedding of each vocabulary element.
+     int dim() const { return cols_; }
+
+     // Returns quantization type for this embedding matrix.
+     QuantizationType quant_type() const { return quant_type_; }
+
+     // Gets embedding for k-th vocabulary element: on return, sets *data to
+     // point to the embedding weights and *scale to the quantization scale (1.0
+     // if no quantization).
+     void get_embedding(int k, const void **data, float *scale) const {
+       CLD3_CHECK(k >= 0);
+       CLD3_CHECK(k < size());
+       *data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_;
+       if (quant_type_ == QuantizationType::NONE) {
+         *scale = 1.0;
+       } else {
+         *scale = Float16To32(quant_scales_[k]);
+       }
+     }
+
+    private:
+     static int GetRowSizeInBytes(int cols, QuantizationType quant_type) {
+       CLD3_DCHECK((quant_type == QuantizationType::NONE) ||
+                   (quant_type == QuantizationType::UINT8));
+       if (quant_type == QuantizationType::NONE) {
+         return cols * sizeof(float);
+       } else {  // QuantizationType::UINT8
+         return cols * sizeof(uint8);
+       }
+     }
+
+     // Vocabulary size.
+     int rows_;
+
+     // Number of elements in each embedding.
+     int cols_;
+
+     QuantizationType quant_type_;
+
+     // Pointer to the embedding weights, in row-major order. This is a pointer
+     // to an array of floats / uint8, depending on the quantization type.
+     // Not owned.
+     const void *data_;
+
+     // Number of bytes for one row. Used to jump to next row in data_.
+     int row_size_in_bytes_;
+
+     // Pointer to quantization scales. nullptr if no quantization. Otherwise,
+     // quant_scales_[i] is scale for embedding of i-th vocabulary element.
+     const float16 *quant_scales_;
+   };
+
+   // An immutable vector that doesn't own the memory that stores the underlying
+   // floats. Can be used e.g., as a wrapper around model weights stored in the
+   // static memory.
+   class VectorWrapper {
+    public:
+     VectorWrapper() : VectorWrapper(nullptr, 0) {}
+
+     // Constructs a vector wrapper around the size consecutive floats that start
+     // at address data. Note: the underlying data should be alive for at least
+     // the lifetime of this VectorWrapper object. That's trivially true if data
+     // points to statically allocated data :)
+     VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
+
+     int size() const { return size_; }
+
+     const float *data() const { return data_; }
+
+    private:
+     const float *data_;  // Not owned.
+     int size_;
+
+     // Doesn't own anything, so it can be copied and assigned at will :)
+   };
+
+   typedef std::vector<VectorWrapper> Matrix;
+   typedef std::vector<float> Vector;
+
+   // Constructs an embedding network using the parameters from model.
+   //
+   // Note: model should stay alive for at least the lifetime of this
+   // EmbeddingNetwork object. TODO(salcianu): remove this constraint: we should
+   // copy all necessary data (except, of course, the static weights) at
+   // construction time and use that, instead of relying on model.
+   explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
+
+   virtual ~EmbeddingNetwork() {}
+
+   // Runs forward computation to fill scores with unnormalized output unit
+   // scores. This is useful for making predictions.
+   void ComputeFinalScores(const std::vector<FeatureVector> &features,
+                           Vector *scores) const;
+
+  private:
+   // Computes the softmax scores (prior to normalization) from the concatenated
+   // representation.
+   template <typename ScaleAdderClass>
+   void FinishComputeFinalScores(const Vector &concat, Vector *scores) const;
+
+   // Constructs the concatenated input embedding vector in place in output
+   // vector concat.
+   void ConcatEmbeddings(const std::vector<FeatureVector> &features,
+                         Vector *concat) const;
+
+   // Pointer to the model object passed to the constructor. Not owned.
+   const EmbeddingNetworkParams *model_;
+
+   // Network parameters.
+
+   // One weight matrix for each embedding.
+   std::vector<EmbeddingMatrix> embedding_matrices_;
+
+   // One weight matrix and one vector of bias weights for each hidden layer.
+   std::vector<Matrix> hidden_weights_;
+   std::vector<VectorWrapper> hidden_bias_;
+
+   // Weight matrix and bias vector for the softmax layer.
+   Matrix softmax_weights_;
+   VectorWrapper softmax_bias_;
+ };
+
+ }  // namespace chrome_lang_id
+
+ #endif  // EMBEDDING_NETWORK_H_
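
EmbeddingMatrix::get_embedding above, together with the quantized branch of ConcatEmbeddings in embedding_network.cc, defines how a UINT8-quantized row is decoded: each byte is centered at 128 and multiplied by the row's float16 scale. A small hypothetical sketch of just that decoding step; DequantizeRow is illustrative only and does not exist in cld3:

#include <cstdint>
#include <vector>

// Decodes one UINT8-quantized embedding row of length dim into floats,
// assuming row_scale has already been widened from float16 (Float16To32).
std::vector<float> DequantizeRow(const std::uint8_t *row, int dim, float row_scale) {
  std::vector<float> out(dim);
  for (int i = 0; i < dim; ++i) {
    // 128 is the bias of the UINT8 scheme, the only quantization cld3 supports.
    out[i] = (static_cast<int>(row[i]) - 128) * row_scale;
  }
  return out;
}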
data/ext/cld3/embedding_network_params.h
@@ -0,0 +1,285 @@
+ /* Copyright 2016 Google Inc. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef EMBEDDING_NETWORK_PARAMS_H_
+ #define EMBEDDING_NETWORK_PARAMS_H_
+
+ #include <string>
+
+ #include "base.h"
+ #include "float16.h"
+
+ namespace chrome_lang_id {
+
+ enum class QuantizationType { NONE = 0, UINT8 };
+
+ // API for accessing parameters from a statically-linked EmbeddingNetworkProto.
+ class EmbeddingNetworkParams {
+  public:
+   virtual ~EmbeddingNetworkParams() {}
+
+   // **** High-level API.
+
+   // Simple representation of a matrix. This small struct that doesn't own any
+   // resource intentionally supports copy / assign, to simplify our APIs.
+   struct Matrix {
+     // Number of rows.
+     int rows;
+
+     // Number of columns.
+     int cols;
+
+     QuantizationType quant_type;
+
+     // Pointer to matrix elements, in row-major order
+     // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
+     const void *elements;
+
+     // Quantization scales: one scale for each row.
+     const float16 *quant_scales;
+   };
+
+   // Returns i-th embedding matrix. Crashes on out of bounds indices.
+   //
+   // This is the transpose of the corresponding matrix from the original proto.
+   Matrix GetEmbeddingMatrix(int i) const {
+     CheckMatrixRange(i, embeddings_size(), "embedding matrix");
+     Matrix matrix;
+     matrix.rows = embeddings_num_rows(i);
+     matrix.cols = embeddings_num_cols(i);
+     matrix.elements = embeddings_weights(i);
+     matrix.quant_type = embeddings_quant_type(i);
+     matrix.quant_scales = embeddings_quant_scales(i);
+     return matrix;
+   }
+
+   // Returns weight matrix for i-th hidden layer. Crashes on out of bounds
+   // indices.
+   //
+   // This is the transpose of the corresponding matrix from the original proto.
+   Matrix GetHiddenLayerMatrix(int i) const {
+     CheckMatrixRange(i, hidden_size(), "hidden layer");
+     Matrix matrix;
+     matrix.rows = hidden_num_rows(i);
+     matrix.cols = hidden_num_cols(i);
+
+     // Quantization not supported here.
+     matrix.quant_type = QuantizationType::NONE;
+     matrix.elements = hidden_weights(i);
+     return matrix;
+   }
+
+   // Returns bias for i-th hidden layer. Technically a Matrix, but we expect it
+   // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+   // don't CHECK for that: we just provide access to underlying data. Crashes
+   // on out of bounds indices.
+   Matrix GetHiddenLayerBias(int i) const {
+     CheckMatrixRange(i, hidden_bias_size(), "hidden layer bias");
+     Matrix matrix;
+     matrix.rows = hidden_bias_num_rows(i);
+     matrix.cols = hidden_bias_num_cols(i);
+
+     // Quantization not supported here.
+     matrix.quant_type = QuantizationType::NONE;
+     matrix.elements = hidden_bias_weights(i);
+     return matrix;
+   }
+
+   // Returns true if a softmax layer exists.
+   bool HasSoftmax() const { return softmax_size() == 1; }
+
+   // Returns weight matrix for the softmax layer. Note: should be called only
+   // if HasSoftmax() is true.
+   //
+   // This is the transpose of the corresponding matrix from the original proto.
+   Matrix GetSoftmaxMatrix() const {
+     CLD3_DCHECK(HasSoftmax());
+     Matrix matrix;
+     matrix.rows = softmax_num_rows(0);
+     matrix.cols = softmax_num_cols(0);
+
+     // Quantization not supported here.
+     matrix.quant_type = QuantizationType::NONE;
+     matrix.elements = softmax_weights(0);
+     return matrix;
+   }
+
+   // Returns bias for the softmax layer. Technically a Matrix, but we expect it
+   // to be a row/column vector (i.e., num rows or num cols is 1). However, we
+   // don't CHECK for that: we just provide access to underlying data.
+   Matrix GetSoftmaxBias() const {
+     CLD3_DCHECK(HasSoftmax());
+     Matrix matrix;
+     matrix.rows = softmax_bias_num_rows(0);
+     matrix.cols = softmax_bias_num_cols(0);
+
+     // Quantization not supported here.
+     matrix.quant_type = QuantizationType::NONE;
+     matrix.elements = softmax_bias_weights(0);
+     return matrix;
+   }
+
+   // **** Low-level API.
+   //
+   // * Most low-level API methods are documented by giving an equivalent
+   // function call on proto, the original proto (of type
+   // EmbeddingNetworkProto) which was used to generate the C++ code.
+   //
+   // * To simplify our generation code, optional proto fields of message type
+   // are treated as repeated fields with 0 or 1 instances. As such, we have
+   // *_size() methods for such optional fields: they return 0 or 1.
+   //
+   // * "transpose(M)" denotes the transpose of a matrix M.
+
+   // ** Access methods for repeated MatrixParams embeddings.
+   //
+   // Returns proto.embeddings_size().
+   virtual int embeddings_size() const = 0;
+
+   // Returns number of rows of transpose(proto.embeddings(i)).
+   virtual int embeddings_num_rows(int i) const = 0;
+
+   // Returns number of columns of transpose(proto.embeddings(i)).
+   virtual int embeddings_num_cols(int i) const = 0;
+
+   // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
+   // order.
+   virtual const void *embeddings_weights(int i) const = 0;
+
+   virtual QuantizationType embeddings_quant_type(int i) const {
+     return QuantizationType::NONE;
+   }
+
+   virtual const float16 *embeddings_quant_scales(int i) const {
+     return nullptr;
+   }
+
+   // ** Access methods for repeated MatrixParams hidden.
+   //
+   // Returns embedding_network_proto.hidden_size().
+   virtual int hidden_size() const = 0;
+
+   // Returns embedding_network_proto.hidden(i).rows().
+   virtual int hidden_num_rows(int i) const = 0;
+
+   // Returns embedding_network_proto.hidden(i).cols().
+   virtual int hidden_num_cols(int i) const = 0;
+
+   // Returns pointer to beginning of array of floats with all values from
+   // embedding_network_proto.hidden(i).
+   virtual const void *hidden_weights(int i) const = 0;
+
+   // ** Access methods for repeated MatrixParams hidden_bias.
+   //
+   // Returns proto.hidden_bias_size().
+   virtual int hidden_bias_size() const = 0;
+
+   // Returns number of rows of proto.hidden_bias(i).
+   virtual int hidden_bias_num_rows(int i) const = 0;
+
+   // Returns number of columns of proto.hidden_bias(i).
+   virtual int hidden_bias_num_cols(int i) const = 0;
+
+   // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
+   virtual const void *hidden_bias_weights(int i) const = 0;
+
+   // ** Access methods for optional MatrixParams softmax.
+   //
+   // Returns 1 if proto has optional field softmax, 0 otherwise.
+   virtual int softmax_size() const = 0;
+
+   // Returns number of rows of transpose(proto.softmax()).
+   virtual int softmax_num_rows(int i) const = 0;
+
+   // Returns number of columns of transpose(proto.softmax()).
+   virtual int softmax_num_cols(int i) const = 0;
+
+   // Returns pointer to elements of transpose(proto.softmax()), in row-major
+   // order.
+   virtual const void *softmax_weights(int i) const = 0;
+
+   // ** Access methods for optional MatrixParams softmax_bias.
+   //
+   // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
+   virtual int softmax_bias_size() const = 0;
+
+   // Returns number of rows of proto.softmax_bias().
+   virtual int softmax_bias_num_rows(int i) const = 0;
+
+   // Returns number of columns of proto.softmax_bias().
+   virtual int softmax_bias_num_cols(int i) const = 0;
+
+   // Returns pointer to elements of proto.softmax_bias(), in row-major order.
+   virtual const void *softmax_bias_weights(int i) const = 0;
+
+   // ** Access methods for repeated int32 embedding_dim.
+   //
+   // Returns proto.embedding_dim_size().
+   virtual int embedding_dim_size() const = 0;
+
+   // Returns proto.embedding_dim(i).
+   virtual int embedding_dim(int i) const = 0;
+
+   // ** Access methods for repeated int32 embedding_num_features.
+   //
+   // Returns proto.embedding_num_features_size().
+   virtual int embedding_num_features_size() const = 0;
+
+   // Returns proto.embedding_num_features(i).
+   virtual int embedding_num_features(int i) const = 0;
+
+   // ** Access methods for repeated int32 embedding_features_domain_size.
+   //
+   // Returns proto.embedding_features_domain_size_size().
+   virtual int embedding_features_domain_size_size() const = 0;
+
+   // Returns proto.embedding_features_domain_size(i).
+   virtual int embedding_features_domain_size(int i) const = 0;
+
+   // ** Access methods for repeated int32 concat_offset.
+   //
+   // Returns proto.concat_offset(i).
+   virtual int concat_offset(int i) const = 0;
+
+   // Returns proto.concat_offset_size().
+   virtual int concat_offset_size() const = 0;
+
+   // ** Access methods for concat_layer_size.
+   //
+   // Returns proto.has_concat_layer_size().
+   virtual bool has_concat_layer_size() const = 0;
+
+   // Returns proto.concat_layer_size().
+   virtual int concat_layer_size() const = 0;
+
+   // ** Access methods for is_precomputed
+   //
+   // Returns proto.has_is_precomputed().
+   virtual bool has_is_precomputed() const = 0;
+
+   // Returns proto.is_precomputed().
+   virtual bool is_precomputed() const = 0;
+
+  private:
+   void CheckMatrixRange(int index, int num_matrices,
+                         const string &description) const {
+     CLD3_DCHECK(index >= 0);
+     CLD3_DCHECK(index < num_matrices);
+   }
+ };  // class EmbeddingNetworkParams
+
+ }  // namespace chrome_lang_id
+
+ #endif  // EMBEDDING_NETWORK_PARAMS_H_
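
The EmbeddingNetwork constructor in embedding_network.cc asserts that concat_offset(i) equals the running sum of embedding_dim(j) * embedding_num_features(j) for j < i, i.e. each embedding space starts right after the combined width of all earlier spaces. A hypothetical helper illustrating that invariant; ComputeConcatOffsets is ours and not part of the gem, and the final running sum is simply the total width of the concatenated embedding vector:

#include <cstddef>
#include <vector>

// Recomputes the expected concat offsets from the per-space dimensions and
// feature counts exposed by the low-level API above.
std::vector<int> ComputeConcatOffsets(const std::vector<int> &embedding_dim,
                                      const std::vector<int> &embedding_num_features) {
  std::vector<int> offsets(embedding_dim.size());
  int offset_sum = 0;
  for (std::size_t i = 0; i < embedding_dim.size(); ++i) {
    offsets[i] = offset_sum;  // Matches CLD3_DCHECK(offset_sum == concat_offset(i)).
    offset_sum += embedding_dim[i] * embedding_num_features[i];
  }
  return offsets;
}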