cld3 3.1.0

Files changed (72)
  1. checksums.yaml +7 -0
  2. data/Gemfile +18 -0
  3. data/LICENSE +204 -0
  4. data/LICENSE_CLD3 +203 -0
  5. data/README.md +22 -0
  6. data/cld3.gemspec +35 -0
  7. data/ext/cld3/base.cc +36 -0
  8. data/ext/cld3/base.h +106 -0
  9. data/ext/cld3/casts.h +98 -0
  10. data/ext/cld3/embedding_feature_extractor.cc +51 -0
  11. data/ext/cld3/embedding_feature_extractor.h +182 -0
  12. data/ext/cld3/embedding_network.cc +196 -0
  13. data/ext/cld3/embedding_network.h +186 -0
  14. data/ext/cld3/embedding_network_params.h +285 -0
  15. data/ext/cld3/extconf.rb +49 -0
  16. data/ext/cld3/feature_extractor.cc +137 -0
  17. data/ext/cld3/feature_extractor.h +633 -0
  18. data/ext/cld3/feature_extractor.proto +50 -0
  19. data/ext/cld3/feature_types.cc +72 -0
  20. data/ext/cld3/feature_types.h +158 -0
  21. data/ext/cld3/fixunicodevalue.cc +55 -0
  22. data/ext/cld3/fixunicodevalue.h +69 -0
  23. data/ext/cld3/float16.h +58 -0
  24. data/ext/cld3/fml_parser.cc +308 -0
  25. data/ext/cld3/fml_parser.h +123 -0
  26. data/ext/cld3/generated_entities.cc +296 -0
  27. data/ext/cld3/generated_ulscript.cc +678 -0
  28. data/ext/cld3/generated_ulscript.h +142 -0
  29. data/ext/cld3/getonescriptspan.cc +1109 -0
  30. data/ext/cld3/getonescriptspan.h +124 -0
  31. data/ext/cld3/integral_types.h +37 -0
  32. data/ext/cld3/lang_id_nn_params.cc +57449 -0
  33. data/ext/cld3/lang_id_nn_params.h +178 -0
  34. data/ext/cld3/language_identifier_features.cc +165 -0
  35. data/ext/cld3/language_identifier_features.h +116 -0
  36. data/ext/cld3/nnet_language_identifier.cc +380 -0
  37. data/ext/cld3/nnet_language_identifier.h +175 -0
  38. data/ext/cld3/nnet_language_identifier_c.cc +72 -0
  39. data/ext/cld3/offsetmap.cc +478 -0
  40. data/ext/cld3/offsetmap.h +168 -0
  41. data/ext/cld3/port.h +143 -0
  42. data/ext/cld3/registry.cc +28 -0
  43. data/ext/cld3/registry.h +242 -0
  44. data/ext/cld3/relevant_script_feature.cc +89 -0
  45. data/ext/cld3/relevant_script_feature.h +49 -0
  46. data/ext/cld3/script_detector.h +156 -0
  47. data/ext/cld3/sentence.proto +77 -0
  48. data/ext/cld3/sentence_features.cc +29 -0
  49. data/ext/cld3/sentence_features.h +35 -0
  50. data/ext/cld3/simple_adder.h +72 -0
  51. data/ext/cld3/stringpiece.h +81 -0
  52. data/ext/cld3/task_context.cc +161 -0
  53. data/ext/cld3/task_context.h +81 -0
  54. data/ext/cld3/task_context_params.cc +74 -0
  55. data/ext/cld3/task_context_params.h +54 -0
  56. data/ext/cld3/task_spec.proto +98 -0
  57. data/ext/cld3/text_processing.cc +245 -0
  58. data/ext/cld3/text_processing.h +30 -0
  59. data/ext/cld3/unicodetext.cc +96 -0
  60. data/ext/cld3/unicodetext.h +144 -0
  61. data/ext/cld3/utf8acceptinterchange.h +486 -0
  62. data/ext/cld3/utf8prop_lettermarkscriptnum.h +1631 -0
  63. data/ext/cld3/utf8repl_lettermarklower.h +758 -0
  64. data/ext/cld3/utf8scannot_lettermarkspecial.h +1455 -0
  65. data/ext/cld3/utf8statetable.cc +1344 -0
  66. data/ext/cld3/utf8statetable.h +285 -0
  67. data/ext/cld3/utils.cc +241 -0
  68. data/ext/cld3/utils.h +144 -0
  69. data/ext/cld3/workspace.cc +64 -0
  70. data/ext/cld3/workspace.h +177 -0
  71. data/lib/cld3.rb +99 -0
  72. metadata +158 -0
data/ext/cld3/nnet_language_identifier.cc
@@ -0,0 +1,380 @@
+ /* Copyright 2016 Google Inc. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "nnet_language_identifier.h"
+
+ #include <math.h>
+
+ #include <algorithm>
+ #include <limits>
+ #include <string>
+
+ #include "base.h"
+ #include "embedding_network.h"
+ #include "registry.h"
+ #include "relevant_script_feature.h"
+ #include "script_span/generated_ulscript.h"
+ #include "script_span/getonescriptspan.h"
+ #include "script_span/text_processing.h"
+ #include "cld_3/protos/sentence.pb.h"
+ #include "sentence_features.h"
+ #include "task_context.h"
+ #include "workspace.h"
+
+ namespace chrome_lang_id {
+ namespace {
+
+ // Struct for accumulating stats for a language as text subsequences of the same
+ // script are processed.
+ struct LangChunksStats {
+   // Sum of probabilities across subsequences.
+   float prob_sum = 0.0;
+
+   // Total number of bytes corresponding to the language.
+   int byte_sum = 0;
+
+   // Number of chunks corresponding to the language.
+   int num_chunks = 0;
+ };
+
+ // Compares two pairs based on their values.
+ bool OrderBySecondDescending(const std::pair<string, float> &x,
+                              const std::pair<string, float> &y) {
+   if (x.second == y.second) {
+     return x.first < y.first;
+   } else {
+     return x.second > y.second;
+   }
+ }
+
+ // Returns "true" if the language prediction is reliable based on the
+ // probability, and "false" otherwise.
+ bool ResultIsReliable(const string &language, float probability) {
+   if (language == "hr" || language == "bs") {
+     return (probability >= NNetLanguageIdentifier::kReliabilityHrBsThreshold);
+   } else {
+     return (probability >= NNetLanguageIdentifier::kReliabilityThreshold);
+   }
+ }
+
+ // Finds the number of interchange-valid bytes to process.
+ int FindNumValidBytesToProcess(const string &text) {
+   // Check if the size of the input text can fit into an int. If not, focus on
+   // the first std::numeric_limits<int>::max() bytes.
+   const int doc_text_size =
+       (text.size() < static_cast<size_t>(std::numeric_limits<int>::max()))
+           ? static_cast<int>(text.size())
+           : std::numeric_limits<int>::max();
+
+   // Truncate the input text if it is too long and find the span containing
+   // interchange-valid UTF8.
+   const int num_valid_bytes = CLD2::SpanInterchangeValid(
+       text.c_str(),
+       std::min(NNetLanguageIdentifier::kMaxNumInputBytesToConsider,
+                doc_text_size));
+
+   return num_valid_bytes;
+ }
+ }  // namespace
+
+ const int NNetLanguageIdentifier::kMinNumBytesToConsider = 140;
+ const int NNetLanguageIdentifier::kMaxNumBytesToConsider = 700;
+ const int NNetLanguageIdentifier::kMaxNumInputBytesToConsider = 10000;
+ const int NNetLanguageIdentifier::kNumSnippets = 5;
+ const char NNetLanguageIdentifier::kUnknown[] = "und";
+ const float NNetLanguageIdentifier::kReliabilityThreshold = 0.7f;
+ const float NNetLanguageIdentifier::kReliabilityHrBsThreshold = 0.5f;
+
+ const string LanguageIdEmbeddingFeatureExtractor::ArgPrefix() const {
+   return "language_identifier";
+ }
+
+ NNetLanguageIdentifier::NNetLanguageIdentifier()
+     : NNetLanguageIdentifier(kMinNumBytesToConsider, kMaxNumBytesToConsider) {}
+
+ static WholeSentenceFeature *cbog_factory() {
+   return new ContinuousBagOfNgramsFunction;
+ }
+
+ static WholeSentenceFeature *rsf_factory() { return new RelevantScriptFeature; }
+
+ static WholeSentenceFeature *sf_factory() { return new ScriptFeature; }
+
+ NNetLanguageIdentifier::NNetLanguageIdentifier(int min_num_bytes,
+                                                int max_num_bytes)
+     : num_languages_(TaskContextParams::GetNumLanguages()),
+       network_(&nn_params_),
+       min_num_bytes_(min_num_bytes),
+       max_num_bytes_(max_num_bytes) {
+   CLD3_CHECK(max_num_bytes_ > 0);
+   CLD3_CHECK(min_num_bytes_ >= 0);
+   CLD3_CHECK(min_num_bytes_ < max_num_bytes_);
+
+   num_snippets_ = (max_num_bytes_ <= kNumSnippets) ? 1 : kNumSnippets;
+   snippet_size_ = max_num_bytes_ / num_snippets_;
+
+   if (WholeSentenceFeature::registry() == nullptr) {
+     // Create registry for our WholeSentenceFeature(s).
+     RegisterableClass<WholeSentenceFeature>::CreateRegistry(
+         "sentence feature function", "WholeSentenceFeature", __FILE__,
+         __LINE__);
+   }
+
+   // Register our WholeSentenceFeature(s).
+   // Register ContinuousBagOfNgramsFunction feature function.
+   static WholeSentenceFeature::Registry::Registrar cbog_registrar(
+       WholeSentenceFeature::registry(), "continuous-bag-of-ngrams",
+       "ContinuousBagOfNgramsFunction", __FILE__, __LINE__, cbog_factory);
+
+   // Register RelevantScriptFeature feature function.
+   static WholeSentenceFeature::Registry::Registrar rsf_registrar(
+       WholeSentenceFeature::registry(), "continuous-bag-of-relevant-scripts",
+       "RelevantScriptFeature", __FILE__, __LINE__, rsf_factory);
+
+   // Register ScriptFeature feature function.
+   static WholeSentenceFeature::Registry::Registrar sf_registrar(
+       WholeSentenceFeature::registry(), "script", "ScriptFeature", __FILE__,
+       __LINE__, sf_factory);
+
+   // Get the model parameters, set up and initialize the model.
+   TaskContext context;
+   TaskContextParams::ToTaskContext(&context);
+   Setup(&context);
+   Init(&context);
+ }
+
+ NNetLanguageIdentifier::~NNetLanguageIdentifier() {}
+
+ void NNetLanguageIdentifier::Setup(TaskContext *context) {
+   feature_extractor_.Setup(context);
+ }
+
+ void NNetLanguageIdentifier::Init(TaskContext *context) {
+   feature_extractor_.Init(context);
+   feature_extractor_.RequestWorkspaces(&workspace_registry_);
+ }
+
+ void NNetLanguageIdentifier::GetFeatures(
+     Sentence *sentence, std::vector<FeatureVector> *features) const {
+   // Feature workspace set.
+   WorkspaceSet workspace;
+   workspace.Reset(workspace_registry_);
+   feature_extractor_.Preprocess(&workspace, sentence);
+   feature_extractor_.ExtractFeatures(workspace, *sentence, features);
+ }
+
+ // Returns the language name corresponding to the given id.
+ string NNetLanguageIdentifier::GetLanguageName(int language_id) const {
+   CLD3_CHECK(language_id >= 0);
+   CLD3_CHECK(language_id < num_languages_);
+   return TaskContextParams::language_names(language_id);
+ }
+
+ NNetLanguageIdentifier::Result NNetLanguageIdentifier::FindLanguage(
+     const string &text) {
+   const int num_valid_bytes = FindNumValidBytesToProcess(text);
+
+   // Iterate over the input with ScriptScanner to clean up the text (e.g.,
+   // removing digits, punctuation, brackets).
+   // TODO(abakalov): Extract the code that does the clean-up out of
+   // ScriptScanner.
+   CLD2::ScriptScanner ss(text.c_str(), num_valid_bytes, /*is_plain_text=*/true);
+   CLD2::LangSpan script_span;
+   string cleaned;
+   while (ss.GetOneScriptSpanLower(&script_span)) {
+     // script_span has spaces at the beginning and the end, so there is no need
+     // for a delimiter.
+     cleaned.append(script_span.text, script_span.text_bytes);
+   }
+
+   if (static_cast<int>(cleaned.size()) < min_num_bytes_) {
+     return Result();
+   }
+
+   // Copy to a vector because a non-const char* will be needed.
+   std::vector<char> text_to_process;
+   for (size_t i = 0; i < cleaned.size(); ++i) {
+     text_to_process.push_back(cleaned[i]);
+   }
+   text_to_process.push_back('\0');
+
+   // Remove repetitive chunks or ones containing mostly spaces.
+   const int chunk_size = 0;  // Use the default.
+   char *text_begin = &text_to_process[0];
+   const int new_length = CLD2::CheapSqueezeInplace(
+       text_begin, text_to_process.size() - 1, chunk_size);
+   if (new_length < min_num_bytes_) {
+     return Result();
+   }
+
+   const string squeezed_text_to_process =
+       SelectTextGivenBeginAndSize(text_begin, new_length);
+   return FindLanguageOfValidUTF8(squeezed_text_to_process);
+ }
+
+ NNetLanguageIdentifier::Result NNetLanguageIdentifier::FindLanguageOfValidUTF8(
+     const string &text) {
+   // Create a Sentence storing the input text.
+   Sentence sentence;
+   sentence.set_text(text);
+
+   // Predict language.
+   // TODO(salcianu): reuse vector<FeatureVector>.
+   std::vector<FeatureVector> features(feature_extractor_.NumEmbeddings());
+   GetFeatures(&sentence, &features);
+
+   EmbeddingNetwork::Vector scores;
+   network_.ComputeFinalScores(features, &scores);
+   int prediction_id = -1;
+   float max_val = -std::numeric_limits<float>::infinity();
+   for (size_t i = 0; i < scores.size(); ++i) {
+     if (scores[i] > max_val) {
+       prediction_id = i;
+       max_val = scores[i];
+     }
+   }
+
+   // Compute probability.
+   Result result;
+   float diff_sum = 0.0;
+   for (size_t i = 0; i < scores.size(); ++i) {
+     diff_sum += exp(scores[i] - max_val);
+   }
+   const float log_sum_exp = max_val + log(diff_sum);
+   result.probability = exp(max_val - log_sum_exp);
+
+   result.language = GetLanguageName(prediction_id);
+   result.is_reliable = ResultIsReliable(result.language, result.probability);
+   result.proportion = 1.0;
+   return result;
+ }
+
+ std::vector<NNetLanguageIdentifier::Result>
+ NNetLanguageIdentifier::FindTopNMostFreqLangs(const string &text,
+                                               int num_langs) {
+   std::vector<Result> results;
+
+   // Truncate the input text if it is too long and find the span containing
+   // interchange-valid UTF8.
+   const int num_valid_bytes = FindNumValidBytesToProcess(text);
+   if (num_valid_bytes == 0) {
+     while (num_langs-- > 0) {
+       results.emplace_back();
+     }
+     return results;
+   }
+
+   // Process each subsequence of the same script.
+   CLD2::ScriptScanner ss(text.c_str(), num_valid_bytes, /*is_plain_text=*/true);
+   CLD2::LangSpan script_span;
+   std::unordered_map<string, LangChunksStats> lang_stats;
+   int total_num_bytes = 0;
+   Result result;
+   string language;
+   int chunk_size = 0;  // Use the default.
+   while (ss.GetOneScriptSpanLower(&script_span)) {
+     const int num_original_span_bytes = script_span.text_bytes;
+
+     // Remove repetitive chunks or ones containing mostly spaces.
+     const int new_length = CLD2::CheapSqueezeInplace(
+         script_span.text, script_span.text_bytes, chunk_size);
+     script_span.text_bytes = new_length;
+
+     if (script_span.text_bytes < min_num_bytes_) {
+       continue;
+     }
+     total_num_bytes += num_original_span_bytes;
+
+     const string selected_text = SelectTextGivenScriptSpan(script_span);
+     result = FindLanguageOfValidUTF8(selected_text);
+     language = result.language;
+     lang_stats[language].byte_sum += num_original_span_bytes;
+     lang_stats[language].prob_sum +=
+         result.probability * num_original_span_bytes;
+     lang_stats[language].num_chunks++;
+   }
+
+   // Sort the languages based on the number of bytes associated with them.
+   // TODO(abakalov): Consider alternative possibly more efficient portable
+   // approaches for finding the top N languages. Given that on average, there
+   // aren't that many languages in the input, it's likely that the benefits will
+   // be negligible (if any).
+   std::vector<std::pair<string, float>> langs_and_byte_counts;
+   for (const auto &entry : lang_stats) {
+     langs_and_byte_counts.emplace_back(entry.first, entry.second.byte_sum);
+   }
+   std::sort(langs_and_byte_counts.begin(), langs_and_byte_counts.end(),
+             OrderBySecondDescending);
+
+   const float byte_sum = static_cast<float>(total_num_bytes);
+   const int num_langs_to_save =
+       std::min(num_langs, static_cast<int>(langs_and_byte_counts.size()));
+   for (int indx = 0; indx < num_langs_to_save; ++indx) {
+     Result result;
+     const string &language = langs_and_byte_counts.at(indx).first;
+     const LangChunksStats &stats = lang_stats.at(language);
+     result.language = language;
+     result.probability = stats.prob_sum / stats.byte_sum;
+     result.proportion = stats.byte_sum / byte_sum;
+     result.is_reliable = ResultIsReliable(language, result.probability);
+     results.push_back(result);
+   }
+
+   int padding_size = num_langs - langs_and_byte_counts.size();
+   while (padding_size-- > 0) {
+     results.emplace_back();
+   }
+   return results;
+ }
+
+ string NNetLanguageIdentifier::SelectTextGivenScriptSpan(
+     const CLD2::LangSpan &script_span) {
+   return SelectTextGivenBeginAndSize(script_span.text, script_span.text_bytes);
+ }
+
+ string NNetLanguageIdentifier::SelectTextGivenBeginAndSize(
+     const char *text_begin, int text_size) {
+   string output_text;
+
+   // If the size of the input is greater than the maximum number of bytes needed
+   // for a prediction, then concatenate snippets that are equally spread out
+   // throughout the input.
+   if (text_size > max_num_bytes_) {
+     const char *snippet_begin = nullptr;
+     const char *snippet_end = text_begin;
+
+     // Number of bytes between the snippets.
+     const int num_skip_bytes =
+         (text_size - max_num_bytes_) / (num_snippets_ + 1);
+
+     for (int i = 0; i < num_snippets_; ++i) {
+       // Using SpanInterchangeValid to find the offsets to ensure that we are
+       // not splitting a character in two.
+       const int actual_num_skip_bytes =
+           CLD2::SpanInterchangeValid(snippet_end, num_skip_bytes);
+       snippet_begin = snippet_end + actual_num_skip_bytes;
+       const int actual_snippet_size =
+           CLD2::SpanInterchangeValid(snippet_begin, snippet_size_);
+       snippet_end = snippet_begin + actual_snippet_size;
+       output_text.append(snippet_begin, actual_snippet_size);
+       output_text.append(" ");
+     }
+   } else {
+     output_text.append(text_begin, text_size);
+   }
+   return output_text;
+ }
+
+ }  // namespace chrome_lang_id
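
Two details in the implementation above are worth spelling out. First, the probability computed in FindLanguageOfValidUTF8 is the softmax mass of the top-scoring language, evaluated with the numerically stable log-sum-exp identity: for raw network scores $s_1, \dots, s_K$ with $m = \max_i s_i$,

    $\log \sum_i e^{s_i} = m + \log \sum_i e^{s_i - m}$,  so  $\text{probability} = e^{m - \log \sum_i e^{s_i}} = e^m / \sum_i e^{s_i}$.

Subtracting $m$ before exponentiating keeps every exponent non-positive and avoids overflow. Second, a worked example of the snippet selection in SelectTextGivenBeginAndSize: with the default constants (max_num_bytes_ = 700, num_snippets_ = 5, snippet_size_ = 700 / 5 = 140), a 10,000-byte input gives num_skip_bytes = (10000 - 700) / (5 + 1) = 1550, so five 140-byte snippets are sampled at roughly 1,550-byte intervals across the input.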
data/ext/cld3/nnet_language_identifier.h
@@ -0,0 +1,175 @@
+ /* Copyright 2016 Google Inc. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef NNET_LANGUAGE_IDENTIFIER_H_
+ #define NNET_LANGUAGE_IDENTIFIER_H_
+
+ #include <string>
+
+ #include "base.h"
+ #include "embedding_feature_extractor.h"
+ #include "embedding_network.h"
+ #include "lang_id_nn_params.h"
+ #include "language_identifier_features.h"
+ #include "script_span/getonescriptspan.h"
+ #include "cld_3/protos/sentence.pb.h"
+ #include "sentence_features.h"
+ #include "task_context.h"
+ #include "task_context_params.h"
+ #include "cld_3/protos/task_spec.pb.h"
+ #include "workspace.h"
+
+ namespace chrome_lang_id {
+
+ // Specialization of the EmbeddingFeatureExtractor for extracting from
+ // (Sentence, int).
+ class LanguageIdEmbeddingFeatureExtractor
+     : public EmbeddingFeatureExtractor<WholeSentenceExtractor, Sentence> {
+  public:
+   const string ArgPrefix() const override;
+ };
+
+ // Class for detecting the language of a document.
+ class NNetLanguageIdentifier {
+  public:
+   // Information about a predicted language.
+   struct Result {
+     string language = kUnknown;
+     float probability = 0.0;   // Language probability.
+     bool is_reliable = false;  // Whether the prediction is reliable.
+
+     // Proportion of bytes associated with the language. If FindLanguage is
+     // called, this variable is set to 1.
+     float proportion = 0.0;
+   };
+
+   NNetLanguageIdentifier();
+   NNetLanguageIdentifier(int min_num_bytes, int max_num_bytes);
+   ~NNetLanguageIdentifier();
+
+   // Finds the most likely language for the given text, along with additional
+   // information (e.g., probability). The prediction is based on the first N
+   // bytes, where N is the minimum of the number of interchange-valid UTF8
+   // bytes and max_num_bytes_. If N is less than min_num_bytes_, this
+   // function returns kUnknown.
+   Result FindLanguage(const string &text);
+
+   // Splits the input text (up to the first byte, if any, that is not
+   // interchange-valid UTF8) into spans based on the script, predicts a language
+   // for each span, and returns a vector storing the top num_langs most frequent
+   // languages along with additional information (e.g., proportions). The number
+   // of bytes considered for each span is the minimum of the size of the
+   // span and max_num_bytes_. If more languages are requested than are
+   // available in the input, kUnknown is returned for the remainder. Also, if
+   // the size of a span is less than min_num_bytes_, the span is
+   // skipped. If the input text is too long, only the first
+   // kMaxNumInputBytesToConsider bytes are processed.
+   std::vector<Result> FindTopNMostFreqLangs(const string &text, int num_langs);
+
+   // String returned when a language is unknown or prediction cannot be made.
+   static const char kUnknown[];
+
+   // Min number of bytes needed to make a prediction if the default constructor
+   // is called.
+   static const int kMinNumBytesToConsider;
+
+   // Max number of bytes to consider to make a prediction if the default
+   // constructor is called.
+   static const int kMaxNumBytesToConsider;
+
+   // Max number of input bytes to process.
+   static const int kMaxNumInputBytesToConsider;
+
+   // Predictions with probability greater than or equal to this threshold are
+   // marked as reliable. This threshold was optimized on a set of text segments
+   // extracted from Wikipedia, and results in an overall precision, recall,
+   // and F1 equal to 0.9760, 0.9624, and 0.9692, respectively.
99
+ static const float kReliabilityThreshold;
100
+
101
+ // Reliability threshold for the languages hr and bs.
102
+ static const float kReliabilityHrBsThreshold;
103
+
104
+ private:
105
+ // Sets up and initializes the model.
106
+ void Setup(TaskContext *context);
107
+ void Init(TaskContext *context);
108
+
109
+ // Extract features from sentence. On return, FeatureVector features[i]
110
+ // contains the features for the embedding space #i.
111
+ void GetFeatures(Sentence *sentence,
112
+ std::vector<FeatureVector> *features) const;
113
+
114
+ // Finds the most likely language for the given text. Assumes that the text is
115
+ // interchange valid UTF8.
116
+ Result FindLanguageOfValidUTF8(const string &text);
117
+
118
+ // Returns the language name corresponding to the given id.
119
+ string GetLanguageName(int language_id) const;
120
+
121
+ // Concatenates snippets of text equally spread out throughout the input if
122
+ // the size of the input is greater than the maximum number of bytes needed to
123
+ // make a prediction. The resulting string is used for language
124
+ // identification.
125
+ string SelectTextGivenScriptSpan(const CLD2::LangSpan &script_span);
126
+ string SelectTextGivenBeginAndSize(const char *text_begin, int text_size);
127
+
128
+ // Number of languages.
129
+ const int num_languages_;
130
+
131
+ // Typed feature extractor for embeddings.
132
+ LanguageIdEmbeddingFeatureExtractor feature_extractor_;
133
+
134
+ // The registry of shared workspaces in the feature extractor.
135
+ WorkspaceRegistry workspace_registry_;
136
+
137
+ // Parameters for the neural networks.
138
+ LangIdNNParams nn_params_;
139
+
140
+ // Neural network to use for scoring.
141
+ EmbeddingNetwork network_;
142
+
143
+ // This feature function is not relevant to this class. Adding this variable
144
+ // ensures that the features are linked.
145
+ ContinuousBagOfNgramsFunction ngram_function_;
146
+
147
+ // Minimum number of bytes needed to make a prediction. If the default
148
+ // constructor is called, this variable is equal to kMinNumBytesToConsider.
149
+ int min_num_bytes_;
150
+
151
+ // Maximum number of bytes to use to make a prediction. If the default
152
+ // constructor is called, this variable is equal to kMaxNumBytesToConsider.
153
+ int max_num_bytes_;
154
+
155
+   // Number of snippets to concatenate to produce the string used for language
+   // identification. If max_num_bytes_ <= kNumSnippets (i.e., the maximum number
+   // of bytes needed to make a prediction is less than or equal to the number of
+   // default snippets), then this variable is equal to 1. Otherwise, it is set
+   // to kNumSnippets.
+   int num_snippets_;
+
+   // The string used to make a prediction is created by concatenating
+   // num_snippets_ snippets of size snippet_size_ = (max_num_bytes_ /
+   // num_snippets_) that are equally spread out throughout the input.
+   int snippet_size_;
+
+   // Default number of snippets to concatenate to produce the string used for
+   // language identification. For the actual number of snippets, see
+   // num_snippets_.
+   static const int kNumSnippets;
+ };
+
+ }  // namespace chrome_lang_id
+
+ #endif  // NNET_LANGUAGE_IDENTIFIER_H_
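
For orientation, here is a minimal C++ driver for the class declared above. It uses only the public API from this header; the byte bounds (0, 1000), the sample strings, and the standalone-build setup are illustrative assumptions, not part of the gem, which reaches this code through its C wrapper (nnet_language_identifier_c.cc) and the Ruby bindings in lib/cld3.rb.

#include <iostream>
#include <string>
#include <vector>

#include "nnet_language_identifier.h"

int main() {
  // The default constructor uses kMinNumBytesToConsider (140) and
  // kMaxNumBytesToConsider (700); custom bounds allow shorter inputs.
  chrome_lang_id::NNetLanguageIdentifier lang_id(/*min_num_bytes=*/0,
                                                 /*max_num_bytes=*/1000);

  // Single most likely language for the whole input.
  const chrome_lang_id::NNetLanguageIdentifier::Result result =
      lang_id.FindLanguage("This text is written in English.");
  std::cout << result.language << " p=" << result.probability
            << " reliable=" << result.is_reliable << "\n";

  // Top-2 languages for input mixing two scripts; if fewer languages are
  // found than requested, the remaining entries come back as kUnknown ("und").
  const std::vector<chrome_lang_id::NNetLanguageIdentifier::Result> top =
      lang_id.FindTopNMostFreqLangs(
          "This part is in English. Αυτό το μέρος είναι στα ελληνικά.",
          /*num_langs=*/2);
  for (const auto &r : top) {
    std::cout << r.language << " proportion=" << r.proportion << "\n";
  }
  return 0;
}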