cld3 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. checksums.yaml +7 -0
  2. data/Gemfile +18 -0
  3. data/LICENSE +204 -0
  4. data/LICENSE_CLD3 +203 -0
  5. data/README.md +22 -0
  6. data/cld3.gemspec +35 -0
  7. data/ext/cld3/base.cc +36 -0
  8. data/ext/cld3/base.h +106 -0
  9. data/ext/cld3/casts.h +98 -0
  10. data/ext/cld3/embedding_feature_extractor.cc +51 -0
  11. data/ext/cld3/embedding_feature_extractor.h +182 -0
  12. data/ext/cld3/embedding_network.cc +196 -0
  13. data/ext/cld3/embedding_network.h +186 -0
  14. data/ext/cld3/embedding_network_params.h +285 -0
  15. data/ext/cld3/extconf.rb +49 -0
  16. data/ext/cld3/feature_extractor.cc +137 -0
  17. data/ext/cld3/feature_extractor.h +633 -0
  18. data/ext/cld3/feature_extractor.proto +50 -0
  19. data/ext/cld3/feature_types.cc +72 -0
  20. data/ext/cld3/feature_types.h +158 -0
  21. data/ext/cld3/fixunicodevalue.cc +55 -0
  22. data/ext/cld3/fixunicodevalue.h +69 -0
  23. data/ext/cld3/float16.h +58 -0
  24. data/ext/cld3/fml_parser.cc +308 -0
  25. data/ext/cld3/fml_parser.h +123 -0
  26. data/ext/cld3/generated_entities.cc +296 -0
  27. data/ext/cld3/generated_ulscript.cc +678 -0
  28. data/ext/cld3/generated_ulscript.h +142 -0
  29. data/ext/cld3/getonescriptspan.cc +1109 -0
  30. data/ext/cld3/getonescriptspan.h +124 -0
  31. data/ext/cld3/integral_types.h +37 -0
  32. data/ext/cld3/lang_id_nn_params.cc +57449 -0
  33. data/ext/cld3/lang_id_nn_params.h +178 -0
  34. data/ext/cld3/language_identifier_features.cc +165 -0
  35. data/ext/cld3/language_identifier_features.h +116 -0
  36. data/ext/cld3/nnet_language_identifier.cc +380 -0
  37. data/ext/cld3/nnet_language_identifier.h +175 -0
  38. data/ext/cld3/nnet_language_identifier_c.cc +72 -0
  39. data/ext/cld3/offsetmap.cc +478 -0
  40. data/ext/cld3/offsetmap.h +168 -0
  41. data/ext/cld3/port.h +143 -0
  42. data/ext/cld3/registry.cc +28 -0
  43. data/ext/cld3/registry.h +242 -0
  44. data/ext/cld3/relevant_script_feature.cc +89 -0
  45. data/ext/cld3/relevant_script_feature.h +49 -0
  46. data/ext/cld3/script_detector.h +156 -0
  47. data/ext/cld3/sentence.proto +77 -0
  48. data/ext/cld3/sentence_features.cc +29 -0
  49. data/ext/cld3/sentence_features.h +35 -0
  50. data/ext/cld3/simple_adder.h +72 -0
  51. data/ext/cld3/stringpiece.h +81 -0
  52. data/ext/cld3/task_context.cc +161 -0
  53. data/ext/cld3/task_context.h +81 -0
  54. data/ext/cld3/task_context_params.cc +74 -0
  55. data/ext/cld3/task_context_params.h +54 -0
  56. data/ext/cld3/task_spec.proto +98 -0
  57. data/ext/cld3/text_processing.cc +245 -0
  58. data/ext/cld3/text_processing.h +30 -0
  59. data/ext/cld3/unicodetext.cc +96 -0
  60. data/ext/cld3/unicodetext.h +144 -0
  61. data/ext/cld3/utf8acceptinterchange.h +486 -0
  62. data/ext/cld3/utf8prop_lettermarkscriptnum.h +1631 -0
  63. data/ext/cld3/utf8repl_lettermarklower.h +758 -0
  64. data/ext/cld3/utf8scannot_lettermarkspecial.h +1455 -0
  65. data/ext/cld3/utf8statetable.cc +1344 -0
  66. data/ext/cld3/utf8statetable.h +285 -0
  67. data/ext/cld3/utils.cc +241 -0
  68. data/ext/cld3/utils.h +144 -0
  69. data/ext/cld3/workspace.cc +64 -0
  70. data/ext/cld3/workspace.h +177 -0
  71. data/lib/cld3.rb +99 -0
  72. metadata +158 -0
@@ -0,0 +1,380 @@
1
+ /* Copyright 2016 Google Inc. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #include "nnet_language_identifier.h"
17
+
18
+ #include <math.h>
19
+
20
+ #include <algorithm>
21
+ #include <limits>
22
+ #include <string>
23
+
24
+ #include "base.h"
25
+ #include "embedding_network.h"
26
+ #include "registry.h"
27
+ #include "relevant_script_feature.h"
28
+ #include "script_span/generated_ulscript.h"
29
+ #include "script_span/getonescriptspan.h"
30
+ #include "script_span/text_processing.h"
31
+ #include "cld_3/protos/sentence.pb.h"
32
+ #include "sentence_features.h"
33
+ #include "task_context.h"
34
+ #include "workspace.h"
35
+
36
+ namespace chrome_lang_id {
37
+ namespace {
38
+
39
// Accumulates per-language statistics while same-script text subsequences
// are processed.
struct LangChunksStats {
  // Sum of probabilities across subsequences.
  float prob_sum = 0.0;

  // Total number of bytes attributed to the language.
  int byte_sum = 0;

  // Number of chunks attributed to the language.
  int num_chunks = 0;
};
51
+
52
+ // Compares two pairs based on their values.
53
+ bool OrderBySecondDescending(const std::pair<string, float> &x,
54
+ const std::pair<string, float> &y) {
55
+ if (x.second == y.second) {
56
+ return x.first < y.first;
57
+ } else {
58
+ return x.second > y.second;
59
+ }
60
+ }
61
+
62
+ // Returns "true" if the languge prediction is reliable based on the
63
+ // probability, and "false" otherwise.
64
+ bool ResultIsReliable(const string &language, float probability) {
65
+ if (language == "hr" || language == "bs") {
66
+ return (probability >= NNetLanguageIdentifier::kReliabilityHrBsThreshold);
67
+ } else {
68
+ return (probability >= NNetLanguageIdentifier::kReliabilityThreshold);
69
+ }
70
+ }
71
+
72
+ // Finds the number of interchange-valid bytes to process.
73
+ int FindNumValidBytesToProcess(const string &text) {
74
+ // Check if the size of the input text can fit into an int. If not, focus on
75
+ // the first std::numeric_limits<int>::max() bytes.
76
+ const int doc_text_size =
77
+ (text.size() < static_cast<size_t>(std::numeric_limits<int>::max()))
78
+ ? static_cast<int>(text.size())
79
+ : std::numeric_limits<int>::max();
80
+
81
+ // Truncate the input text if it is too long and find the span containing
82
+ // interchange-valid UTF8.
83
+ const int num_valid_bytes = CLD2::SpanInterchangeValid(
84
+ text.c_str(),
85
+ std::min(NNetLanguageIdentifier::kMaxNumInputBytesToConsider,
86
+ doc_text_size));
87
+
88
+ return num_valid_bytes;
89
+ }
90
+ } // namespace
91
+
92
+ const int NNetLanguageIdentifier::kMinNumBytesToConsider = 140;
93
+ const int NNetLanguageIdentifier::kMaxNumBytesToConsider = 700;
94
+ const int NNetLanguageIdentifier::kMaxNumInputBytesToConsider = 10000;
95
+ const int NNetLanguageIdentifier::kNumSnippets = 5;
96
+ const char NNetLanguageIdentifier::kUnknown[] = "und";
97
+ const float NNetLanguageIdentifier::kReliabilityThreshold = 0.7f;
98
+ const float NNetLanguageIdentifier::kReliabilityHrBsThreshold = 0.5f;
99
+
100
+ const string LanguageIdEmbeddingFeatureExtractor::ArgPrefix() const {
101
+ return "language_identifier";
102
+ }
103
+
104
+ NNetLanguageIdentifier::NNetLanguageIdentifier()
105
+ : NNetLanguageIdentifier(kMinNumBytesToConsider, kMaxNumBytesToConsider) {}
106
+
107
+ static WholeSentenceFeature *cbog_factory() {
108
+ return new ContinuousBagOfNgramsFunction;
109
+ }
110
+
111
+ static WholeSentenceFeature *rsf_factory() { return new RelevantScriptFeature; }
112
+
113
+ static WholeSentenceFeature *sf_factory() { return new ScriptFeature; }
114
+
115
+ NNetLanguageIdentifier::NNetLanguageIdentifier(int min_num_bytes,
116
+ int max_num_bytes)
117
+ : num_languages_(TaskContextParams::GetNumLanguages()),
118
+ network_(&nn_params_),
119
+ min_num_bytes_(min_num_bytes),
120
+ max_num_bytes_(max_num_bytes) {
121
+ CLD3_CHECK(max_num_bytes_ > 0);
122
+ CLD3_CHECK(min_num_bytes_ >= 0);
123
+ CLD3_CHECK(min_num_bytes_ < max_num_bytes_);
124
+
125
+ num_snippets_ = (max_num_bytes_ <= kNumSnippets) ? 1 : kNumSnippets;
126
+ snippet_size_ = max_num_bytes_ / num_snippets_;
127
+
128
+ if (WholeSentenceFeature::registry() == nullptr) {
129
+ // Create registry for our WholeSentenceFeature(s).
130
+ RegisterableClass<WholeSentenceFeature>::CreateRegistry(
131
+ "sentence feature function", "WholeSentenceFeature", __FILE__,
132
+ __LINE__);
133
+ }
134
+
135
+ // Register our WholeSentenceFeature(s).
136
+ // Register ContinuousBagOfNgramsFunction feature function.
137
+ static WholeSentenceFeature::Registry::Registrar cbog_registrar(
138
+ WholeSentenceFeature::registry(), "continuous-bag-of-ngrams",
139
+ "ContinuousBagOfNgramsFunction", __FILE__, __LINE__, cbog_factory);
140
+
141
+ // Register RelevantScriptFeature feature function.
142
+ static WholeSentenceFeature::Registry::Registrar rsf_registrar(
143
+ WholeSentenceFeature::registry(), "continuous-bag-of-relevant-scripts",
144
+ "RelevantScriptFeature", __FILE__, __LINE__, rsf_factory);
145
+
146
+ // Register ScriptFeature feature function.
147
+ static WholeSentenceFeature::Registry::Registrar sf_registrar(
148
+ WholeSentenceFeature::registry(), "script", "ScriptFeature", __FILE__,
149
+ __LINE__, sf_factory);
150
+
151
+ // Get the model parameters, set up and initialize the model.
152
+ TaskContext context;
153
+ TaskContextParams::ToTaskContext(&context);
154
+ Setup(&context);
155
+ Init(&context);
156
+ }
157
+
158
+ NNetLanguageIdentifier::~NNetLanguageIdentifier() {}
159
+
160
+ void NNetLanguageIdentifier::Setup(TaskContext *context) {
161
+ feature_extractor_.Setup(context);
162
+ }
163
+
164
+ void NNetLanguageIdentifier::Init(TaskContext *context) {
165
+ feature_extractor_.Init(context);
166
+ feature_extractor_.RequestWorkspaces(&workspace_registry_);
167
+ }
168
+
169
+ void NNetLanguageIdentifier::GetFeatures(
170
+ Sentence *sentence, std::vector<FeatureVector> *features) const {
171
+ // Feature workspace set.
172
+ WorkspaceSet workspace;
173
+ workspace.Reset(workspace_registry_);
174
+ feature_extractor_.Preprocess(&workspace, sentence);
175
+ feature_extractor_.ExtractFeatures(workspace, *sentence, features);
176
+ }
177
+
178
+ // Returns the language name corresponding to the given id.
179
+ string NNetLanguageIdentifier::GetLanguageName(int language_id) const {
180
+ CLD3_CHECK(language_id >= 0);
181
+ CLD3_CHECK(language_id < num_languages_);
182
+ return TaskContextParams::language_names(language_id);
183
+ }
184
+
185
+ NNetLanguageIdentifier::Result NNetLanguageIdentifier::FindLanguage(
186
+ const string &text) {
187
+ const int num_valid_bytes = FindNumValidBytesToProcess(text);
188
+
189
+ // Iterate over the input with ScriptScanner to clean up the text (e.g.,
190
+ // removing digits, punctuation, brackets).
191
+ // TODO(abakalov): Extract the code that does the clean-up out of
192
+ // ScriptScanner.
193
+ CLD2::ScriptScanner ss(text.c_str(), num_valid_bytes, /*is_plain_text=*/true);
194
+ CLD2::LangSpan script_span;
195
+ string cleaned;
196
+ while (ss.GetOneScriptSpanLower(&script_span)) {
197
+ // script_span has spaces at the beginning and the end, so there is no need
198
+ // for a delimiter.
199
+ cleaned.append(script_span.text, script_span.text_bytes);
200
+ }
201
+
202
+ if (static_cast<int>(cleaned.size()) < min_num_bytes_) {
203
+ return Result();
204
+ }
205
+
206
+ // Copy to a vector because a non-const char* will be needed.
207
+ std::vector<char> text_to_process;
208
+ for (size_t i = 0; i < cleaned.size(); ++i) {
209
+ text_to_process.push_back(cleaned[i]);
210
+ }
211
+ text_to_process.push_back('\0');
212
+
213
+ // Remove repetitive chunks or ones containing mostly spaces.
214
+ const int chunk_size = 0; // Use the default.
215
+ char *text_begin = &text_to_process[0];
216
+ const int new_length = CLD2::CheapSqueezeInplace(
217
+ text_begin, text_to_process.size() - 1, chunk_size);
218
+ if (new_length < min_num_bytes_) {
219
+ return Result();
220
+ }
221
+
222
+ const string squeezed_text_to_process =
223
+ SelectTextGivenBeginAndSize(text_begin, new_length);
224
+ return FindLanguageOfValidUTF8(squeezed_text_to_process);
225
+ }
226
+
227
+ NNetLanguageIdentifier::Result NNetLanguageIdentifier::FindLanguageOfValidUTF8(
228
+ const string &text) {
229
+ // Create a Sentence storing the input text.
230
+ Sentence sentence;
231
+ sentence.set_text(text);
232
+
233
+ // Predict language.
234
+ // TODO(salcianu): reuse vector<FeatureVector>.
235
+ std::vector<FeatureVector> features(feature_extractor_.NumEmbeddings());
236
+ GetFeatures(&sentence, &features);
237
+
238
+ EmbeddingNetwork::Vector scores;
239
+ network_.ComputeFinalScores(features, &scores);
240
+ int prediction_id = -1;
241
+ float max_val = -std::numeric_limits<float>::infinity();
242
+ for (size_t i = 0; i < scores.size(); ++i) {
243
+ if (scores[i] > max_val) {
244
+ prediction_id = i;
245
+ max_val = scores[i];
246
+ }
247
+ }
248
+
249
+ // Compute probability.
250
+ Result result;
251
+ float diff_sum = 0.0;
252
+ for (size_t i = 0; i < scores.size(); ++i) {
253
+ diff_sum += exp(scores[i] - max_val);
254
+ }
255
+ const float log_sum_exp = max_val + log(diff_sum);
256
+ result.probability = exp(max_val - log_sum_exp);
257
+
258
+ result.language = GetLanguageName(prediction_id);
259
+ result.is_reliable = ResultIsReliable(result.language, result.probability);
260
+ result.proportion = 1.0;
261
+ return result;
262
+ }
263
+
264
+ std::vector<NNetLanguageIdentifier::Result>
265
+ NNetLanguageIdentifier::FindTopNMostFreqLangs(const string &text,
266
+ int num_langs) {
267
+ std::vector<Result> results;
268
+
269
+ // Truncate the input text if it is too long and find the span containing
270
+ // interchange-valid UTF8.
271
+ const int num_valid_bytes = FindNumValidBytesToProcess(text);
272
+ if (num_valid_bytes == 0) {
273
+ while (num_langs-- > 0) {
274
+ results.emplace_back();
275
+ }
276
+ return results;
277
+ }
278
+
279
+ // Process each subsequence of the same script.
280
+ CLD2::ScriptScanner ss(text.c_str(), num_valid_bytes, /*is_plain_text=*/true);
281
+ CLD2::LangSpan script_span;
282
+ std::unordered_map<string, LangChunksStats> lang_stats;
283
+ int total_num_bytes = 0;
284
+ Result result;
285
+ string language;
286
+ int chunk_size = 0; // Use the default.
287
+ while (ss.GetOneScriptSpanLower(&script_span)) {
288
+ const int num_original_span_bytes = script_span.text_bytes;
289
+
290
+ // Remove repetitive chunks or ones containing mostly spaces.
291
+ const int new_length = CLD2::CheapSqueezeInplace(
292
+ script_span.text, script_span.text_bytes, chunk_size);
293
+ script_span.text_bytes = new_length;
294
+
295
+ if (script_span.text_bytes < min_num_bytes_) {
296
+ continue;
297
+ }
298
+ total_num_bytes += num_original_span_bytes;
299
+
300
+ const string selected_text = SelectTextGivenScriptSpan(script_span);
301
+ result = FindLanguageOfValidUTF8(selected_text);
302
+ language = result.language;
303
+ lang_stats[language].byte_sum += num_original_span_bytes;
304
+ lang_stats[language].prob_sum +=
305
+ result.probability * num_original_span_bytes;
306
+ lang_stats[language].num_chunks++;
307
+ }
308
+
309
+ // Sort the languages based on the number of bytes associated with them.
310
+ // TODO(abakalov): Consider alternative possibly more efficient portable
311
+ // approaches for finding the top N languages. Given that on average, there
312
+ // aren't that many languages in the input, it's likely that the benefits will
313
+ // be negligible (if any).
314
+ std::vector<std::pair<string, float>> langs_and_byte_counts;
315
+ for (const auto &entry : lang_stats) {
316
+ langs_and_byte_counts.emplace_back(entry.first, entry.second.byte_sum);
317
+ }
318
+ std::sort(langs_and_byte_counts.begin(), langs_and_byte_counts.end(),
319
+ OrderBySecondDescending);
320
+
321
+ const float byte_sum = static_cast<float>(total_num_bytes);
322
+ const int num_langs_to_save =
323
+ std::min(num_langs, static_cast<int>(langs_and_byte_counts.size()));
324
+ for (int indx = 0; indx < num_langs_to_save; ++indx) {
325
+ Result result;
326
+ const string &language = langs_and_byte_counts.at(indx).first;
327
+ const LangChunksStats &stats = lang_stats.at(language);
328
+ result.language = language;
329
+ result.probability = stats.prob_sum / stats.byte_sum;
330
+ result.proportion = stats.byte_sum / byte_sum;
331
+ result.is_reliable = ResultIsReliable(language, result.probability);
332
+ results.push_back(result);
333
+ }
334
+
335
+ int padding_size = num_langs - langs_and_byte_counts.size();
336
+ while (padding_size-- > 0) {
337
+ results.emplace_back();
338
+ }
339
+ return results;
340
+ }
341
+
342
+ string NNetLanguageIdentifier::SelectTextGivenScriptSpan(
343
+ const CLD2::LangSpan &script_span) {
344
+ return SelectTextGivenBeginAndSize(script_span.text, script_span.text_bytes);
345
+ }
346
+
347
+ string NNetLanguageIdentifier::SelectTextGivenBeginAndSize(
348
+ const char *text_begin, int text_size) {
349
+ string output_text;
350
+
351
+ // If the size of the input is greater than the maxium number of bytes needed
352
+ // for a prediction, then concatenate snippets that are equally spread out
353
+ // throughout the input.
354
+ if (text_size > max_num_bytes_) {
355
+ const char *snippet_begin = nullptr;
356
+ const char *snippet_end = text_begin;
357
+
358
+ // Number of bytes between the snippets.
359
+ const int num_skip_bytes =
360
+ (text_size - max_num_bytes_) / (num_snippets_ + 1);
361
+
362
+ for (int i = 0; i < num_snippets_; ++i) {
363
+ // Using SpanInterchangeValid to find the offsets to ensure that we are
364
+ // not splitting a character in two.
365
+ const int actual_num_skip_bytes =
366
+ CLD2::SpanInterchangeValid(snippet_end, num_skip_bytes);
367
+ snippet_begin = snippet_end + actual_num_skip_bytes;
368
+ const int actual_snippet_size =
369
+ CLD2::SpanInterchangeValid(snippet_begin, snippet_size_);
370
+ snippet_end = snippet_begin + actual_snippet_size;
371
+ output_text.append(snippet_begin, actual_snippet_size);
372
+ output_text.append(" ");
373
+ }
374
+ } else {
375
+ output_text.append(text_begin, text_size);
376
+ }
377
+ return output_text;
378
+ }
379
+
380
+ } // namespace chrome_lang_id
@@ -0,0 +1,175 @@
1
+ /* Copyright 2016 Google Inc. All Rights Reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #ifndef NNET_LANGUAGE_IDENTIFIER_H_
17
+ #define NNET_LANGUAGE_IDENTIFIER_H_
18
+
19
+ #include <string>
20
+
21
+ #include "base.h"
22
+ #include "embedding_feature_extractor.h"
23
+ #include "embedding_network.h"
24
+ #include "lang_id_nn_params.h"
25
+ #include "language_identifier_features.h"
26
+ #include "script_span/getonescriptspan.h"
27
+ #include "cld_3/protos/sentence.pb.h"
28
+ #include "sentence_features.h"
29
+ #include "task_context.h"
30
+ #include "task_context_params.h"
31
+ #include "cld_3/protos/task_spec.pb.h"
32
+ #include "workspace.h"
33
+
34
+ namespace chrome_lang_id {
35
+
36
+ // Specialization of the EmbeddingFeatureExtractor for extracting from
37
+ // (Sentence, int).
38
+ class LanguageIdEmbeddingFeatureExtractor
39
+ : public EmbeddingFeatureExtractor<WholeSentenceExtractor, Sentence> {
40
+ public:
41
+ const string ArgPrefix() const override;
42
+ };
43
+
44
+ // Class for detecting the language of a document.
45
+ class NNetLanguageIdentifier {
46
+ public:
47
+ // Information about a predicted language.
48
+ struct Result {
49
+ string language = kUnknown;
50
+ float probability = 0.0; // Language probability.
51
+ bool is_reliable = false; // Whether the prediction is reliable.
52
+
53
+ // Proportion of bytes associated with the language. If FindLanguage is
54
+ // called, this variable is set to 1.
55
+ float proportion = 0.0;
56
+ };
57
+
58
+ NNetLanguageIdentifier();
59
+ NNetLanguageIdentifier(int min_num_bytes, int max_num_bytes);
60
+ ~NNetLanguageIdentifier();
61
+
62
+ // Finds the most likely language for the given text, along with additional
63
+ // information (e.g., probability). The prediction is based on the first N
64
+ // bytes where N is the minumum between the number of interchange valid UTF8
65
+ // bytes and max_num_bytes_. If N is less than min_num_bytes_ long, then this
66
+ // function returns kUnknown.
67
+ Result FindLanguage(const string &text);
68
+
69
+ // Splits the input text (up to the first byte, if any, that is not
70
+ // interchange valid UTF8) into spans based on the script, predicts a language
71
+ // for each span, and returns a vector storing the top num_langs most frequent
72
+ // languages along with additional information (e.g., proportions). The number
73
+ // of bytes considered for each span is the minimum between the size of the
74
+ // span and max_num_bytes_. If more languages are requested than what is
75
+ // available in the input, then for those cases kUnknown is returned. Also, if
76
+ // the size of the span is less than min_num_bytes_ long, then the span is
77
+ // skipped. If the input text is too long, only the first
78
+ // kMaxNumInputBytesToConsider bytes are processed.
79
+ std::vector<Result> FindTopNMostFreqLangs(const string &text, int num_langs);
80
+
81
+ // String returned when a language is unknown or prediction cannot be made.
82
+ static const char kUnknown[];
83
+
84
+ // Min number of bytes needed to make a prediction if the default constructor
85
+ // is called.
86
+ static const int kMinNumBytesToConsider;
87
+
88
+ // Max number of bytes to consider to make a prediction if the default
89
+ // constructor is called.
90
+ static const int kMaxNumBytesToConsider;
91
+
92
+ // Max number of input bytes to process.
93
+ static const int kMaxNumInputBytesToConsider;
94
+
95
+ // Predictions with probability greater than or equal to this threshold are
96
+ // marked as reliable. This threshold was optimized on a set of text segments
97
+ // extracted from wikipedia, and results in an overall precision, recall,
98
+ // and f1 equal to 0.9760, 0.9624, and 0.9692, respectively.
99
+ static const float kReliabilityThreshold;
100
+
101
+ // Reliability threshold for the languages hr and bs.
102
+ static const float kReliabilityHrBsThreshold;
103
+
104
+ private:
105
+ // Sets up and initializes the model.
106
+ void Setup(TaskContext *context);
107
+ void Init(TaskContext *context);
108
+
109
+ // Extract features from sentence. On return, FeatureVector features[i]
110
+ // contains the features for the embedding space #i.
111
+ void GetFeatures(Sentence *sentence,
112
+ std::vector<FeatureVector> *features) const;
113
+
114
+ // Finds the most likely language for the given text. Assumes that the text is
115
+ // interchange valid UTF8.
116
+ Result FindLanguageOfValidUTF8(const string &text);
117
+
118
+ // Returns the language name corresponding to the given id.
119
+ string GetLanguageName(int language_id) const;
120
+
121
+ // Concatenates snippets of text equally spread out throughout the input if
122
+ // the size of the input is greater than the maximum number of bytes needed to
123
+ // make a prediction. The resulting string is used for language
124
+ // identification.
125
+ string SelectTextGivenScriptSpan(const CLD2::LangSpan &script_span);
126
+ string SelectTextGivenBeginAndSize(const char *text_begin, int text_size);
127
+
128
+ // Number of languages.
129
+ const int num_languages_;
130
+
131
+ // Typed feature extractor for embeddings.
132
+ LanguageIdEmbeddingFeatureExtractor feature_extractor_;
133
+
134
+ // The registry of shared workspaces in the feature extractor.
135
+ WorkspaceRegistry workspace_registry_;
136
+
137
+ // Parameters for the neural networks.
138
+ LangIdNNParams nn_params_;
139
+
140
+ // Neural network to use for scoring.
141
+ EmbeddingNetwork network_;
142
+
143
+ // This feature function is not relevant to this class. Adding this variable
144
+ // ensures that the features are linked.
145
+ ContinuousBagOfNgramsFunction ngram_function_;
146
+
147
+ // Minimum number of bytes needed to make a prediction. If the default
148
+ // constructor is called, this variable is equal to kMinNumBytesToConsider.
149
+ int min_num_bytes_;
150
+
151
+ // Maximum number of bytes to use to make a prediction. If the default
152
+ // constructor is called, this variable is equal to kMaxNumBytesToConsider.
153
+ int max_num_bytes_;
154
+
155
+ // Number of snippets to concatenate to produce the string used for language
156
+ // identification. If max_num_bytes_ <= kNumSnippets (i.e., the maximum number
157
+ // of bytes needed to make a prediction is smaller or equal to the number of
158
+ // default snippets), then this variable is equal to 1. Otherwise, it is set
159
+ // to kNumSnippets.
160
+ int num_snippets_;
161
+
162
+ // The string used to make a prediction is created by concatenating
163
+ // num_snippets_ snippets of size snippet_size_ = (max_num_bytes_ /
164
+ // num_snippets_) that are equaly spread out throughout the input.
165
+ int snippet_size_;
166
+
167
+ // Default number of snippets to concatenate to produce the string used for
168
+ // language identification. For the actual number of snippets, see
169
+ // num_snippets_.
170
+ static const int kNumSnippets;
171
+ };
172
+
173
+ } // namespace chrome_lang_id
174
+
175
+ #endif // NNET_LANGUAGE_IDENTIFIER_H_