cld3 3.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Gemfile +18 -0
- data/LICENSE +204 -0
- data/LICENSE_CLD3 +203 -0
- data/README.md +22 -0
- data/cld3.gemspec +35 -0
- data/ext/cld3/base.cc +36 -0
- data/ext/cld3/base.h +106 -0
- data/ext/cld3/casts.h +98 -0
- data/ext/cld3/embedding_feature_extractor.cc +51 -0
- data/ext/cld3/embedding_feature_extractor.h +182 -0
- data/ext/cld3/embedding_network.cc +196 -0
- data/ext/cld3/embedding_network.h +186 -0
- data/ext/cld3/embedding_network_params.h +285 -0
- data/ext/cld3/extconf.rb +49 -0
- data/ext/cld3/feature_extractor.cc +137 -0
- data/ext/cld3/feature_extractor.h +633 -0
- data/ext/cld3/feature_extractor.proto +50 -0
- data/ext/cld3/feature_types.cc +72 -0
- data/ext/cld3/feature_types.h +158 -0
- data/ext/cld3/fixunicodevalue.cc +55 -0
- data/ext/cld3/fixunicodevalue.h +69 -0
- data/ext/cld3/float16.h +58 -0
- data/ext/cld3/fml_parser.cc +308 -0
- data/ext/cld3/fml_parser.h +123 -0
- data/ext/cld3/generated_entities.cc +296 -0
- data/ext/cld3/generated_ulscript.cc +678 -0
- data/ext/cld3/generated_ulscript.h +142 -0
- data/ext/cld3/getonescriptspan.cc +1109 -0
- data/ext/cld3/getonescriptspan.h +124 -0
- data/ext/cld3/integral_types.h +37 -0
- data/ext/cld3/lang_id_nn_params.cc +57449 -0
- data/ext/cld3/lang_id_nn_params.h +178 -0
- data/ext/cld3/language_identifier_features.cc +165 -0
- data/ext/cld3/language_identifier_features.h +116 -0
- data/ext/cld3/nnet_language_identifier.cc +380 -0
- data/ext/cld3/nnet_language_identifier.h +175 -0
- data/ext/cld3/nnet_language_identifier_c.cc +72 -0
- data/ext/cld3/offsetmap.cc +478 -0
- data/ext/cld3/offsetmap.h +168 -0
- data/ext/cld3/port.h +143 -0
- data/ext/cld3/registry.cc +28 -0
- data/ext/cld3/registry.h +242 -0
- data/ext/cld3/relevant_script_feature.cc +89 -0
- data/ext/cld3/relevant_script_feature.h +49 -0
- data/ext/cld3/script_detector.h +156 -0
- data/ext/cld3/sentence.proto +77 -0
- data/ext/cld3/sentence_features.cc +29 -0
- data/ext/cld3/sentence_features.h +35 -0
- data/ext/cld3/simple_adder.h +72 -0
- data/ext/cld3/stringpiece.h +81 -0
- data/ext/cld3/task_context.cc +161 -0
- data/ext/cld3/task_context.h +81 -0
- data/ext/cld3/task_context_params.cc +74 -0
- data/ext/cld3/task_context_params.h +54 -0
- data/ext/cld3/task_spec.proto +98 -0
- data/ext/cld3/text_processing.cc +245 -0
- data/ext/cld3/text_processing.h +30 -0
- data/ext/cld3/unicodetext.cc +96 -0
- data/ext/cld3/unicodetext.h +144 -0
- data/ext/cld3/utf8acceptinterchange.h +486 -0
- data/ext/cld3/utf8prop_lettermarkscriptnum.h +1631 -0
- data/ext/cld3/utf8repl_lettermarklower.h +758 -0
- data/ext/cld3/utf8scannot_lettermarkspecial.h +1455 -0
- data/ext/cld3/utf8statetable.cc +1344 -0
- data/ext/cld3/utf8statetable.h +285 -0
- data/ext/cld3/utils.cc +241 -0
- data/ext/cld3/utils.h +144 -0
- data/ext/cld3/workspace.cc +64 -0
- data/ext/cld3/workspace.h +177 -0
- data/lib/cld3.rb +99 -0
- metadata +158 -0
@@ -0,0 +1,380 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#include "nnet_language_identifier.h"
|
17
|
+
|
18
|
+
#include <math.h>
|
19
|
+
|
20
|
+
#include <algorithm>
|
21
|
+
#include <limits>
|
22
|
+
#include <string>
|
23
|
+
|
24
|
+
#include "base.h"
|
25
|
+
#include "embedding_network.h"
|
26
|
+
#include "registry.h"
|
27
|
+
#include "relevant_script_feature.h"
|
28
|
+
#include "script_span/generated_ulscript.h"
|
29
|
+
#include "script_span/getonescriptspan.h"
|
30
|
+
#include "script_span/text_processing.h"
|
31
|
+
#include "cld_3/protos/sentence.pb.h"
|
32
|
+
#include "sentence_features.h"
|
33
|
+
#include "task_context.h"
|
34
|
+
#include "workspace.h"
|
35
|
+
|
36
|
+
namespace chrome_lang_id {
|
37
|
+
namespace {
|
38
|
+
|
39
|
+
// Struct for accumulating stats for a language as text subsequences of the same
// script are processed.
struct LangChunksStats {
  // Sum of probabilities across subsequences, weighted by byte count (see
  // FindTopNMostFreqLangs, which adds probability * num_original_span_bytes).
  float prob_sum = 0.0;

  // Total number of bytes corresponding to the language.
  int byte_sum = 0;

  // Number of chunks corresponding to the language.
  int num_chunks = 0;
};
|
51
|
+
|
52
|
+
// Compares two pairs based on their values.
|
53
|
+
bool OrderBySecondDescending(const std::pair<string, float> &x,
|
54
|
+
const std::pair<string, float> &y) {
|
55
|
+
if (x.second == y.second) {
|
56
|
+
return x.first < y.first;
|
57
|
+
} else {
|
58
|
+
return x.second > y.second;
|
59
|
+
}
|
60
|
+
}
|
61
|
+
|
62
|
+
// Returns "true" if the language prediction is reliable based on the
// probability, and "false" otherwise. Croatian ("hr") and Bosnian ("bs") use a
// lower, dedicated threshold than all other languages.
bool ResultIsReliable(const string &language, float probability) {
  if (language == "hr" || language == "bs") {
    return (probability >= NNetLanguageIdentifier::kReliabilityHrBsThreshold);
  } else {
    return (probability >= NNetLanguageIdentifier::kReliabilityThreshold);
  }
}
|
71
|
+
|
72
|
+
// Finds the number of interchange-valid bytes to process.
int FindNumValidBytesToProcess(const string &text) {
  // Check if the size of the input text can fit into an int. If not, focus on
  // the first std::numeric_limits<int>::max() bytes.
  const int doc_text_size =
      (text.size() < static_cast<size_t>(std::numeric_limits<int>::max()))
          ? static_cast<int>(text.size())
          : std::numeric_limits<int>::max();

  // Truncate the input text if it is too long and find the span containing
  // interchange-valid UTF8.
  const int num_valid_bytes = CLD2::SpanInterchangeValid(
      text.c_str(),
      std::min(NNetLanguageIdentifier::kMaxNumInputBytesToConsider,
               doc_text_size));

  return num_valid_bytes;
}
|
90
|
+
} // namespace
|
91
|
+
|
92
|
+
// Definitions for the static class constants; see nnet_language_identifier.h
// for the detailed semantics of each.
const int NNetLanguageIdentifier::kMinNumBytesToConsider = 140;
const int NNetLanguageIdentifier::kMaxNumBytesToConsider = 700;
const int NNetLanguageIdentifier::kMaxNumInputBytesToConsider = 10000;
const int NNetLanguageIdentifier::kNumSnippets = 5;

// Language code returned when no prediction can be made ("undetermined").
const char NNetLanguageIdentifier::kUnknown[] = "und";

// Probability thresholds at or above which a prediction is marked reliable.
const float NNetLanguageIdentifier::kReliabilityThreshold = 0.7f;
const float NNetLanguageIdentifier::kReliabilityHrBsThreshold = 0.5f;
|
99
|
+
|
100
|
+
// Returns the argument prefix for this feature extractor.
const string LanguageIdEmbeddingFeatureExtractor::ArgPrefix() const {
  return "language_identifier";
}
|
103
|
+
|
104
|
+
// Default constructor: delegates to the two-argument constructor using the
// default minimum/maximum byte-count bounds.
NNetLanguageIdentifier::NNetLanguageIdentifier()
    : NNetLanguageIdentifier(kMinNumBytesToConsider, kMaxNumBytesToConsider) {}
|
106
|
+
|
107
|
+
// Factory functions passed to the WholeSentenceFeature registrars in the
// constructor below; each returns a newly allocated feature-function instance.
static WholeSentenceFeature *cbog_factory() {
  return new ContinuousBagOfNgramsFunction;
}

static WholeSentenceFeature *rsf_factory() { return new RelevantScriptFeature; }

static WholeSentenceFeature *sf_factory() { return new ScriptFeature; }
|
114
|
+
|
115
|
+
// Constructs an identifier that requires at least min_num_bytes of cleaned
// input to predict and uses at most max_num_bytes of it. Creates the feature
// registry on first use, registers the feature functions (via function-local
// statics, so registration happens once per process), and initializes the
// model from the task context.
NNetLanguageIdentifier::NNetLanguageIdentifier(int min_num_bytes,
                                               int max_num_bytes)
    : num_languages_(TaskContextParams::GetNumLanguages()),
      network_(&nn_params_),
      min_num_bytes_(min_num_bytes),
      max_num_bytes_(max_num_bytes) {
  // Validate the byte bounds: 0 <= min < max, max > 0.
  CLD3_CHECK(max_num_bytes_ > 0);
  CLD3_CHECK(min_num_bytes_ >= 0);
  CLD3_CHECK(min_num_bytes_ < max_num_bytes_);

  // Use a single snippet when the prediction budget is no larger than the
  // default snippet count; otherwise split the budget across kNumSnippets.
  num_snippets_ = (max_num_bytes_ <= kNumSnippets) ? 1 : kNumSnippets;
  snippet_size_ = max_num_bytes_ / num_snippets_;

  if (WholeSentenceFeature::registry() == nullptr) {
    // Create registry for our WholeSentenceFeature(s).
    RegisterableClass<WholeSentenceFeature>::CreateRegistry(
        "sentence feature function", "WholeSentenceFeature", __FILE__,
        __LINE__);
  }

  // Register our WholeSentenceFeature(s).
  // Register ContinuousBagOfNgramsFunction feature function.
  static WholeSentenceFeature::Registry::Registrar cbog_registrar(
      WholeSentenceFeature::registry(), "continuous-bag-of-ngrams",
      "ContinuousBagOfNgramsFunction", __FILE__, __LINE__, cbog_factory);

  // Register RelevantScriptFeature feature function.
  static WholeSentenceFeature::Registry::Registrar rsf_registrar(
      WholeSentenceFeature::registry(), "continuous-bag-of-relevant-scripts",
      "RelevantScriptFeature", __FILE__, __LINE__, rsf_factory);

  // Register ScriptFeature feature function.
  static WholeSentenceFeature::Registry::Registrar sf_registrar(
      WholeSentenceFeature::registry(), "script", "ScriptFeature", __FILE__,
      __LINE__, sf_factory);

  // Get the model parameters, set up and initialize the model.
  TaskContext context;
  TaskContextParams::ToTaskContext(&context);
  Setup(&context);
  Init(&context);
}
|
157
|
+
|
158
|
+
// Destructor: defaulted — equivalent to the empty body it replaces; all
// members clean up through their own destructors.
NNetLanguageIdentifier::~NNetLanguageIdentifier() = default;
|
159
|
+
|
160
|
+
// Forwards model setup to the embedded feature extractor.
void NNetLanguageIdentifier::Setup(TaskContext *context) {
  feature_extractor_.Setup(context);
}
|
163
|
+
|
164
|
+
// Initializes the feature extractor and registers the workspaces it needs in
// workspace_registry_.
void NNetLanguageIdentifier::Init(TaskContext *context) {
  feature_extractor_.Init(context);
  feature_extractor_.RequestWorkspaces(&workspace_registry_);
}
|
168
|
+
|
169
|
+
// Extracts features for the given sentence. On return, (*features)[i] holds
// the features for embedding space #i.
void NNetLanguageIdentifier::GetFeatures(
    Sentence *sentence, std::vector<FeatureVector> *features) const {
  // Feature workspace set, freshly bound to the registry for this extraction.
  WorkspaceSet workspace;
  workspace.Reset(workspace_registry_);
  feature_extractor_.Preprocess(&workspace, sentence);
  feature_extractor_.ExtractFeatures(workspace, *sentence, features);
}
|
177
|
+
|
178
|
+
// Returns the language name corresponding to the given id. The id must be in
// [0, num_languages_); out-of-range ids trip the CLD3_CHECKs below.
string NNetLanguageIdentifier::GetLanguageName(int language_id) const {
  CLD3_CHECK(language_id >= 0);
  CLD3_CHECK(language_id < num_languages_);
  return TaskContextParams::language_names(language_id);
}
|
184
|
+
|
185
|
+
// Finds the most likely language for the given text. The input is first
// cleaned (script-scanned and squeezed); if fewer than min_num_bytes_ remain
// at either stage, a default (unknown/unreliable) Result is returned.
NNetLanguageIdentifier::Result NNetLanguageIdentifier::FindLanguage(
    const string &text) {
  const int num_valid_bytes = FindNumValidBytesToProcess(text);

  // Iterate over the input with ScriptScanner to clean up the text (e.g.,
  // removing digits, punctuation, brackets).
  // TODO(abakalov): Extract the code that does the clean-up out of
  // ScriptScanner.
  CLD2::ScriptScanner ss(text.c_str(), num_valid_bytes, /*is_plain_text=*/true);
  CLD2::LangSpan script_span;
  string cleaned;
  while (ss.GetOneScriptSpanLower(&script_span)) {
    // script_span has spaces at the beginning and the end, so there is no need
    // for a delimiter.
    cleaned.append(script_span.text, script_span.text_bytes);
  }

  if (static_cast<int>(cleaned.size()) < min_num_bytes_) {
    return Result();
  }

  // Copy to a vector because a non-const char* will be needed.
  std::vector<char> text_to_process;
  for (size_t i = 0; i < cleaned.size(); ++i) {
    text_to_process.push_back(cleaned[i]);
  }
  text_to_process.push_back('\0');

  // Remove repetitive chunks or ones containing mostly spaces. The squeeze is
  // done in place; new_length is the number of bytes that remain valid in the
  // buffer afterwards.
  const int chunk_size = 0;  // Use the default.
  char *text_begin = &text_to_process[0];
  const int new_length = CLD2::CheapSqueezeInplace(
      text_begin, text_to_process.size() - 1, chunk_size);
  if (new_length < min_num_bytes_) {
    return Result();
  }

  const string squeezed_text_to_process =
      SelectTextGivenBeginAndSize(text_begin, new_length);
  return FindLanguageOfValidUTF8(squeezed_text_to_process);
}
|
226
|
+
|
227
|
+
// Finds the most likely language for the given text, which is assumed to be
// interchange-valid UTF8. Runs the embedding network over the extracted
// features, takes the arg-max score, and converts it to a softmax probability.
NNetLanguageIdentifier::Result NNetLanguageIdentifier::FindLanguageOfValidUTF8(
    const string &text) {
  // Create a Sentence storing the input text.
  Sentence sentence;
  sentence.set_text(text);

  // Predict language.
  // TODO(salcianu): reuse vector<FeatureVector>.
  std::vector<FeatureVector> features(feature_extractor_.NumEmbeddings());
  GetFeatures(&sentence, &features);

  EmbeddingNetwork::Vector scores;
  network_.ComputeFinalScores(features, &scores);

  // Find the highest-scoring class.
  int prediction_id = -1;
  float max_val = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < scores.size(); ++i) {
    if (scores[i] > max_val) {
      prediction_id = i;
      max_val = scores[i];
    }
  }

  // Guard against an empty score vector: without this, prediction_id == -1
  // would trip the CLD3_CHECK(language_id >= 0) inside GetLanguageName.
  // Return the default (kUnknown, unreliable) result instead.
  if (prediction_id < 0) {
    return Result();
  }

  // Compute the softmax probability of the top class, using the log-sum-exp
  // trick (subtracting max_val before exponentiating) for numerical stability.
  Result result;
  float diff_sum = 0.0;
  for (size_t i = 0; i < scores.size(); ++i) {
    diff_sum += exp(scores[i] - max_val);
  }
  const float log_sum_exp = max_val + log(diff_sum);
  result.probability = exp(max_val - log_sum_exp);

  result.language = GetLanguageName(prediction_id);
  result.is_reliable = ResultIsReliable(result.language, result.probability);
  result.proportion = 1.0;
  return result;
}
|
263
|
+
|
264
|
+
std::vector<NNetLanguageIdentifier::Result>
|
265
|
+
NNetLanguageIdentifier::FindTopNMostFreqLangs(const string &text,
|
266
|
+
int num_langs) {
|
267
|
+
std::vector<Result> results;
|
268
|
+
|
269
|
+
// Truncate the input text if it is too long and find the span containing
|
270
|
+
// interchange-valid UTF8.
|
271
|
+
const int num_valid_bytes = FindNumValidBytesToProcess(text);
|
272
|
+
if (num_valid_bytes == 0) {
|
273
|
+
while (num_langs-- > 0) {
|
274
|
+
results.emplace_back();
|
275
|
+
}
|
276
|
+
return results;
|
277
|
+
}
|
278
|
+
|
279
|
+
// Process each subsequence of the same script.
|
280
|
+
CLD2::ScriptScanner ss(text.c_str(), num_valid_bytes, /*is_plain_text=*/true);
|
281
|
+
CLD2::LangSpan script_span;
|
282
|
+
std::unordered_map<string, LangChunksStats> lang_stats;
|
283
|
+
int total_num_bytes = 0;
|
284
|
+
Result result;
|
285
|
+
string language;
|
286
|
+
int chunk_size = 0; // Use the default.
|
287
|
+
while (ss.GetOneScriptSpanLower(&script_span)) {
|
288
|
+
const int num_original_span_bytes = script_span.text_bytes;
|
289
|
+
|
290
|
+
// Remove repetitive chunks or ones containing mostly spaces.
|
291
|
+
const int new_length = CLD2::CheapSqueezeInplace(
|
292
|
+
script_span.text, script_span.text_bytes, chunk_size);
|
293
|
+
script_span.text_bytes = new_length;
|
294
|
+
|
295
|
+
if (script_span.text_bytes < min_num_bytes_) {
|
296
|
+
continue;
|
297
|
+
}
|
298
|
+
total_num_bytes += num_original_span_bytes;
|
299
|
+
|
300
|
+
const string selected_text = SelectTextGivenScriptSpan(script_span);
|
301
|
+
result = FindLanguageOfValidUTF8(selected_text);
|
302
|
+
language = result.language;
|
303
|
+
lang_stats[language].byte_sum += num_original_span_bytes;
|
304
|
+
lang_stats[language].prob_sum +=
|
305
|
+
result.probability * num_original_span_bytes;
|
306
|
+
lang_stats[language].num_chunks++;
|
307
|
+
}
|
308
|
+
|
309
|
+
// Sort the languages based on the number of bytes associated with them.
|
310
|
+
// TODO(abakalov): Consider alternative possibly more efficient portable
|
311
|
+
// approaches for finding the top N languages. Given that on average, there
|
312
|
+
// aren't that many languages in the input, it's likely that the benefits will
|
313
|
+
// be negligible (if any).
|
314
|
+
std::vector<std::pair<string, float>> langs_and_byte_counts;
|
315
|
+
for (const auto &entry : lang_stats) {
|
316
|
+
langs_and_byte_counts.emplace_back(entry.first, entry.second.byte_sum);
|
317
|
+
}
|
318
|
+
std::sort(langs_and_byte_counts.begin(), langs_and_byte_counts.end(),
|
319
|
+
OrderBySecondDescending);
|
320
|
+
|
321
|
+
const float byte_sum = static_cast<float>(total_num_bytes);
|
322
|
+
const int num_langs_to_save =
|
323
|
+
std::min(num_langs, static_cast<int>(langs_and_byte_counts.size()));
|
324
|
+
for (int indx = 0; indx < num_langs_to_save; ++indx) {
|
325
|
+
Result result;
|
326
|
+
const string &language = langs_and_byte_counts.at(indx).first;
|
327
|
+
const LangChunksStats &stats = lang_stats.at(language);
|
328
|
+
result.language = language;
|
329
|
+
result.probability = stats.prob_sum / stats.byte_sum;
|
330
|
+
result.proportion = stats.byte_sum / byte_sum;
|
331
|
+
result.is_reliable = ResultIsReliable(language, result.probability);
|
332
|
+
results.push_back(result);
|
333
|
+
}
|
334
|
+
|
335
|
+
int padding_size = num_langs - langs_and_byte_counts.size();
|
336
|
+
while (padding_size-- > 0) {
|
337
|
+
results.emplace_back();
|
338
|
+
}
|
339
|
+
return results;
|
340
|
+
}
|
341
|
+
|
342
|
+
// Convenience overload: selects text for prediction from a script span by
// delegating to SelectTextGivenBeginAndSize.
string NNetLanguageIdentifier::SelectTextGivenScriptSpan(
    const CLD2::LangSpan &script_span) {
  return SelectTextGivenBeginAndSize(script_span.text, script_span.text_bytes);
}
|
346
|
+
|
347
|
+
// Builds the string used for prediction from [text_begin, text_begin +
// text_size). If the input fits within max_num_bytes_, it is used verbatim;
// otherwise num_snippets_ snippets of (up to) snippet_size_ bytes are
// concatenated, equally spread out throughout the input.
string NNetLanguageIdentifier::SelectTextGivenBeginAndSize(
    const char *text_begin, int text_size) {
  string output_text;

  // If the size of the input is greater than the maximum number of bytes
  // needed for a prediction, then concatenate snippets that are equally spread
  // out throughout the input.
  if (text_size > max_num_bytes_) {
    const char *snippet_begin = nullptr;
    const char *snippet_end = text_begin;

    // Number of bytes between the snippets.
    const int num_skip_bytes =
        (text_size - max_num_bytes_) / (num_snippets_ + 1);

    for (int i = 0; i < num_snippets_; ++i) {
      // Using SpanInterchangeValid to find the offsets to ensure that we are
      // not splitting a character in two.
      const int actual_num_skip_bytes =
          CLD2::SpanInterchangeValid(snippet_end, num_skip_bytes);
      snippet_begin = snippet_end + actual_num_skip_bytes;
      const int actual_snippet_size =
          CLD2::SpanInterchangeValid(snippet_begin, snippet_size_);
      snippet_end = snippet_begin + actual_snippet_size;
      output_text.append(snippet_begin, actual_snippet_size);
      // Separate snippets with a space so words at the boundaries do not fuse.
      output_text.append(" ");
    }
  } else {
    output_text.append(text_begin, text_size);
  }
  return output_text;
}
|
379
|
+
|
380
|
+
} // namespace chrome_lang_id
|
@@ -0,0 +1,175 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#ifndef NNET_LANGUAGE_IDENTIFIER_H_
|
17
|
+
#define NNET_LANGUAGE_IDENTIFIER_H_
|
18
|
+
|
19
|
+
#include <string>
|
20
|
+
|
21
|
+
#include "base.h"
|
22
|
+
#include "embedding_feature_extractor.h"
|
23
|
+
#include "embedding_network.h"
|
24
|
+
#include "lang_id_nn_params.h"
|
25
|
+
#include "language_identifier_features.h"
|
26
|
+
#include "script_span/getonescriptspan.h"
|
27
|
+
#include "cld_3/protos/sentence.pb.h"
|
28
|
+
#include "sentence_features.h"
|
29
|
+
#include "task_context.h"
|
30
|
+
#include "task_context_params.h"
|
31
|
+
#include "cld_3/protos/task_spec.pb.h"
|
32
|
+
#include "workspace.h"
|
33
|
+
|
34
|
+
namespace chrome_lang_id {
|
35
|
+
|
36
|
+
// Specialization of the EmbeddingFeatureExtractor for extracting from
// (Sentence, int).
class LanguageIdEmbeddingFeatureExtractor
    : public EmbeddingFeatureExtractor<WholeSentenceExtractor, Sentence> {
 public:
  // Returns the argument prefix for this extractor ("language_identifier";
  // see the definition in nnet_language_identifier.cc).
  const string ArgPrefix() const override;
};
|
43
|
+
|
44
|
+
// Class for detecting the language of a document.
class NNetLanguageIdentifier {
 public:
  // Information about a predicted language.
  struct Result {
    string language = kUnknown;
    float probability = 0.0;   // Language probability.
    bool is_reliable = false;  // Whether the prediction is reliable.

    // Proportion of bytes associated with the language. If FindLanguage is
    // called, this variable is set to 1.
    float proportion = 0.0;
  };

  NNetLanguageIdentifier();
  NNetLanguageIdentifier(int min_num_bytes, int max_num_bytes);
  ~NNetLanguageIdentifier();

  // Finds the most likely language for the given text, along with additional
  // information (e.g., probability). The prediction is based on the first N
  // bytes where N is the minimum between the number of interchange valid UTF8
  // bytes and max_num_bytes_. If N is less than min_num_bytes_ long, then this
  // function returns kUnknown.
  Result FindLanguage(const string &text);

  // Splits the input text (up to the first byte, if any, that is not
  // interchange valid UTF8) into spans based on the script, predicts a language
  // for each span, and returns a vector storing the top num_langs most frequent
  // languages along with additional information (e.g., proportions). The number
  // of bytes considered for each span is the minimum between the size of the
  // span and max_num_bytes_. If more languages are requested than what is
  // available in the input, then for those cases kUnknown is returned. Also, if
  // the size of the span is less than min_num_bytes_ long, then the span is
  // skipped. If the input text is too long, only the first
  // kMaxNumInputBytesToConsider bytes are processed.
  std::vector<Result> FindTopNMostFreqLangs(const string &text, int num_langs);

  // String returned when a language is unknown or prediction cannot be made.
  static const char kUnknown[];

  // Min number of bytes needed to make a prediction if the default constructor
  // is called.
  static const int kMinNumBytesToConsider;

  // Max number of bytes to consider to make a prediction if the default
  // constructor is called.
  static const int kMaxNumBytesToConsider;

  // Max number of input bytes to process.
  static const int kMaxNumInputBytesToConsider;

  // Predictions with probability greater than or equal to this threshold are
  // marked as reliable. This threshold was optimized on a set of text segments
  // extracted from wikipedia, and results in an overall precision, recall,
  // and f1 equal to 0.9760, 0.9624, and 0.9692, respectively.
  static const float kReliabilityThreshold;

  // Reliability threshold for the languages hr and bs.
  static const float kReliabilityHrBsThreshold;

 private:
  // Sets up and initializes the model.
  void Setup(TaskContext *context);
  void Init(TaskContext *context);

  // Extract features from sentence. On return, FeatureVector features[i]
  // contains the features for the embedding space #i.
  void GetFeatures(Sentence *sentence,
                   std::vector<FeatureVector> *features) const;

  // Finds the most likely language for the given text. Assumes that the text is
  // interchange valid UTF8.
  Result FindLanguageOfValidUTF8(const string &text);

  // Returns the language name corresponding to the given id.
  string GetLanguageName(int language_id) const;

  // Concatenates snippets of text equally spread out throughout the input if
  // the size of the input is greater than the maximum number of bytes needed to
  // make a prediction. The resulting string is used for language
  // identification.
  string SelectTextGivenScriptSpan(const CLD2::LangSpan &script_span);
  string SelectTextGivenBeginAndSize(const char *text_begin, int text_size);

  // Number of languages.
  const int num_languages_;

  // Typed feature extractor for embeddings.
  LanguageIdEmbeddingFeatureExtractor feature_extractor_;

  // The registry of shared workspaces in the feature extractor.
  WorkspaceRegistry workspace_registry_;

  // Parameters for the neural networks.
  LangIdNNParams nn_params_;

  // Neural network to use for scoring.
  EmbeddingNetwork network_;

  // This feature function is not relevant to this class. Adding this variable
  // ensures that the features are linked.
  ContinuousBagOfNgramsFunction ngram_function_;

  // Minimum number of bytes needed to make a prediction. If the default
  // constructor is called, this variable is equal to kMinNumBytesToConsider.
  int min_num_bytes_;

  // Maximum number of bytes to use to make a prediction. If the default
  // constructor is called, this variable is equal to kMaxNumBytesToConsider.
  int max_num_bytes_;

  // Number of snippets to concatenate to produce the string used for language
  // identification. If max_num_bytes_ <= kNumSnippets (i.e., the maximum number
  // of bytes needed to make a prediction is smaller or equal to the number of
  // default snippets), then this variable is equal to 1. Otherwise, it is set
  // to kNumSnippets.
  int num_snippets_;

  // The string used to make a prediction is created by concatenating
  // num_snippets_ snippets of size snippet_size_ = (max_num_bytes_ /
  // num_snippets_) that are equally spread out throughout the input.
  int snippet_size_;

  // Default number of snippets to concatenate to produce the string used for
  // language identification. For the actual number of snippets, see
  // num_snippets_.
  static const int kNumSnippets;
};
|
172
|
+
|
173
|
+
} // namespace chrome_lang_id
|
174
|
+
|
175
|
+
#endif // NNET_LANGUAGE_IDENTIFIER_H_
|