cld3 3.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/Gemfile +18 -0
- data/LICENSE +204 -0
- data/LICENSE_CLD3 +203 -0
- data/README.md +22 -0
- data/cld3.gemspec +35 -0
- data/ext/cld3/base.cc +36 -0
- data/ext/cld3/base.h +106 -0
- data/ext/cld3/casts.h +98 -0
- data/ext/cld3/embedding_feature_extractor.cc +51 -0
- data/ext/cld3/embedding_feature_extractor.h +182 -0
- data/ext/cld3/embedding_network.cc +196 -0
- data/ext/cld3/embedding_network.h +186 -0
- data/ext/cld3/embedding_network_params.h +285 -0
- data/ext/cld3/extconf.rb +49 -0
- data/ext/cld3/feature_extractor.cc +137 -0
- data/ext/cld3/feature_extractor.h +633 -0
- data/ext/cld3/feature_extractor.proto +50 -0
- data/ext/cld3/feature_types.cc +72 -0
- data/ext/cld3/feature_types.h +158 -0
- data/ext/cld3/fixunicodevalue.cc +55 -0
- data/ext/cld3/fixunicodevalue.h +69 -0
- data/ext/cld3/float16.h +58 -0
- data/ext/cld3/fml_parser.cc +308 -0
- data/ext/cld3/fml_parser.h +123 -0
- data/ext/cld3/generated_entities.cc +296 -0
- data/ext/cld3/generated_ulscript.cc +678 -0
- data/ext/cld3/generated_ulscript.h +142 -0
- data/ext/cld3/getonescriptspan.cc +1109 -0
- data/ext/cld3/getonescriptspan.h +124 -0
- data/ext/cld3/integral_types.h +37 -0
- data/ext/cld3/lang_id_nn_params.cc +57449 -0
- data/ext/cld3/lang_id_nn_params.h +178 -0
- data/ext/cld3/language_identifier_features.cc +165 -0
- data/ext/cld3/language_identifier_features.h +116 -0
- data/ext/cld3/nnet_language_identifier.cc +380 -0
- data/ext/cld3/nnet_language_identifier.h +175 -0
- data/ext/cld3/nnet_language_identifier_c.cc +72 -0
- data/ext/cld3/offsetmap.cc +478 -0
- data/ext/cld3/offsetmap.h +168 -0
- data/ext/cld3/port.h +143 -0
- data/ext/cld3/registry.cc +28 -0
- data/ext/cld3/registry.h +242 -0
- data/ext/cld3/relevant_script_feature.cc +89 -0
- data/ext/cld3/relevant_script_feature.h +49 -0
- data/ext/cld3/script_detector.h +156 -0
- data/ext/cld3/sentence.proto +77 -0
- data/ext/cld3/sentence_features.cc +29 -0
- data/ext/cld3/sentence_features.h +35 -0
- data/ext/cld3/simple_adder.h +72 -0
- data/ext/cld3/stringpiece.h +81 -0
- data/ext/cld3/task_context.cc +161 -0
- data/ext/cld3/task_context.h +81 -0
- data/ext/cld3/task_context_params.cc +74 -0
- data/ext/cld3/task_context_params.h +54 -0
- data/ext/cld3/task_spec.proto +98 -0
- data/ext/cld3/text_processing.cc +245 -0
- data/ext/cld3/text_processing.h +30 -0
- data/ext/cld3/unicodetext.cc +96 -0
- data/ext/cld3/unicodetext.h +144 -0
- data/ext/cld3/utf8acceptinterchange.h +486 -0
- data/ext/cld3/utf8prop_lettermarkscriptnum.h +1631 -0
- data/ext/cld3/utf8repl_lettermarklower.h +758 -0
- data/ext/cld3/utf8scannot_lettermarkspecial.h +1455 -0
- data/ext/cld3/utf8statetable.cc +1344 -0
- data/ext/cld3/utf8statetable.h +285 -0
- data/ext/cld3/utils.cc +241 -0
- data/ext/cld3/utils.h +144 -0
- data/ext/cld3/workspace.cc +64 -0
- data/ext/cld3/workspace.h +177 -0
- data/lib/cld3.rb +99 -0
- metadata +158 -0
data/cld3.gemspec
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# Copyright 2017 Akihiko Odaki <akihiko.odaki.4i@stu.hosei.ac.jp>
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
#==============================================================================
|
16
|
+
|
17
|
+
# Gem specification for the cld3 Ruby binding: metadata, version constraints,
# packaged files, and the native extension build entry point.
Gem::Specification.new do |spec|
  # Identity and descriptive metadata.
  spec.name = "cld3"
  spec.version = "3.1.0"
  spec.summary = "Compact Language Detector v3 (CLD3)"
  spec.description = "Compact Language Detector v3 (CLD3) is a neural network model for language identification."
  spec.license = "Apache-2.0"
  spec.homepage = "https://github.com/akihikodaki/cld3-ruby"
  spec.author = "Akihiko Odaki"
  spec.email = "akihiko.odaki.4i@stu.hosei.ac.jp"

  # Supported interpreter range and dependency constraints.
  spec.required_ruby_version = [ ">= 2.3.0", "< 2.5.0" ]
  spec.add_dependency "ffi", ">= 1.1.0", "< 1.10.0"
  spec.add_development_dependency "rspec", ">=2.11.0", "< 3.7.0"

  # Files shipped in the gem: top-level documents plus everything found under
  # ext/ and lib/ at build time.
  packaged_patterns = [
    "Gemfile", "LICENSE", "LICENSE_CLD3", "README.md",
    "cld3.gemspec", "ext/**/*", "lib/**/*"
  ]
  spec.files = Dir[*packaged_patterns]

  # Ruby code is loaded from lib/; the native code is built via extconf.rb.
  spec.require_paths = [ "lib" ]
  spec.extensions = [ "ext/cld3/extconf.rb" ]
end
|
data/ext/cld3/base.cc
ADDED
@@ -0,0 +1,36 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#include "base.h"
|
17
|
+
|
18
|
+
#include <string>
|
19
|
+
#ifdef COMPILER_MSVC
|
20
|
+
#include <sstream>
|
21
|
+
#endif // COMPILER_MSVC
|
22
|
+
|
23
|
+
namespace chrome_lang_id {
|
24
|
+
|
25
|
+
// TODO(abakalov): Pick the most efficient approach.
|
26
|
+
#ifdef COMPILER_MSVC
|
27
|
+
std::string Int64ToString(int64 input) {
|
28
|
+
std::stringstream stream;
|
29
|
+
stream << input;
|
30
|
+
return stream.str();
|
31
|
+
}
|
32
|
+
#else
|
33
|
+
std::string Int64ToString(int64 input) { return std::to_string(input); }
|
34
|
+
#endif // COMPILER_MSVC
|
35
|
+
|
36
|
+
} // namespace chrome_lang_id
|
data/ext/cld3/base.h
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#ifndef BASE_H_
|
17
|
+
#define BASE_H_
|
18
|
+
|
19
|
+
#include <cassert>
|
20
|
+
#include <map>
|
21
|
+
#include <string>
|
22
|
+
#include <vector>
|
23
|
+
|
24
|
+
namespace chrome_lang_id {

// Standard-library names that the rest of the library uses unqualified.
using std::vector;
using std::string;
using std::map;
using std::pair;
typedef unsigned int uint32;

// CLD3_DISALLOW_COPY_AND_ASSIGN(TypeName) declares the copy constructor and
// copy assignment operator of TypeName so that the class cannot be copied.
//
// The deleted-function form is selected when LANG_CXX11 is set, and also when
// the compiler itself reports C++11 support (__cplusplus, or MSVC 2013 and
// later via _MSC_VER). Relying on LANG_CXX11 alone would leave ordinary C++11
// builds on the weaker declare-only fallback below.
#if LANG_CXX11 || (defined(__cplusplus) && __cplusplus >= 201103L) || \
    (defined(_MSC_VER) && _MSC_VER >= 1800)
#define CLD3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName &) = delete;          \
  TypeName &operator=(const TypeName &) = delete
#else  // C++98 case follows

// Note that these C++98 implementations cannot completely disallow copying,
// as members and friends can still accidentally make elided copies without
// triggering a linker error.
#define CLD3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName &);                   \
  TypeName &operator=(const TypeName &)
#endif  // LANG_CXX11

// CLD3_IMMEDIATE_CRASH() terminates the program on the spot. It may be
// pre-defined by the embedder; otherwise it uses __builtin_trap() on GCC and
// Clang and a write through a null volatile pointer elsewhere.
#ifndef CLD3_IMMEDIATE_CRASH
#if defined(__GNUC__) || defined(__clang__)
#define CLD3_IMMEDIATE_CRASH() __builtin_trap()
#else
#define CLD3_IMMEDIATE_CRASH() ((void)(*(volatile char *)0 = 0))
#endif
#endif  // CLD3_IMMEDIATE_CRASH

// CLD3_CHECK(f) crashes immediately when |f| evaluates to false.
#define CLD3_CHECK(f) (!(f) ? CLD3_IMMEDIATE_CRASH() : (void)0)

// CLD3_DCHECK(f) behaves like CLD3_CHECK(f), except that it expands to a
// no-op (|f| is not evaluated) when NDEBUG is defined and DCHECK_ALWAYS_ON is
// not.
#if defined(NDEBUG) && !defined(DCHECK_ALWAYS_ON)
#define CLD3_DCHECK(f) ((void)0)
#else
#define CLD3_DCHECK(f) CLD3_CHECK(f)
#endif

#ifndef SWIG
typedef int int32;
typedef unsigned char uint8;    // NOLINT
typedef unsigned short uint16;  // NOLINT

// A type to represent a Unicode code-point value. As of Unicode 4.0,
// such values require up to 21 bits.
// (For type-checking on pointers, make this explicitly signed,
// and it should always be the signed version of whatever int32 is.)
typedef signed int char32;
#endif  // SWIG

// 64-bit signed integer type; MSVC builds use the compiler-specific __int64.
#ifdef COMPILER_MSVC
typedef __int64 int64;
#else
typedef long long int64;  // NOLINT
#endif  // COMPILER_MSVC

#if defined(__GNUC__) && \
    (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))

// For functions we want to force inline.
// Introduced in gcc 3.1.
#define CLD3_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))

#elif defined(_MSC_VER)
#define CLD3_ATTRIBUTE_ALWAYS_INLINE __forceinline
#else

// Other compilers will have to figure it out for themselves.
#define CLD3_ATTRIBUTE_ALWAYS_INLINE
#endif

// Byte-string alias. INTERNAL_BUILD expects an unqualified basic_string to be
// visible; all other builds use the std:: one.
#ifdef INTERNAL_BUILD
typedef basic_string<char> bstring;
#else
typedef std::basic_string<char> bstring;
#endif  // INTERNAL_BUILD

// Converts int64 to string.
std::string Int64ToString(int64 input);

}  // namespace chrome_lang_id
|
105
|
+
|
106
|
+
#endif // BASE_H_
|
data/ext/cld3/casts.h
ADDED
@@ -0,0 +1,98 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
// This code is compiled directly on many platforms, including client
|
17
|
+
// platforms like Windows, Mac, and embedded systems. Before making
|
18
|
+
// any changes here, make sure that you're not breaking any platforms.
|
19
|
+
//
|
20
|
+
|
21
|
+
#ifndef CASTS_H_
|
22
|
+
#define CASTS_H_
|
23
|
+
|
24
|
+
#include <string.h> // for memcpy
|
25
|
+
|
26
|
+
namespace chrome_lang_id {
|
27
|
+
|
28
|
+
// lang_id_bit_cast<Dest,Source> is a template function that implements the
|
29
|
+
// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
|
30
|
+
// very low-level functions like the protobuf library and fast math
|
31
|
+
// support.
|
32
|
+
//
|
33
|
+
// float f = 3.14159265358979;
|
34
|
+
// int i = lang_id_bit_cast<int32>(f);
|
35
|
+
// // i = 0x40490fdb
|
36
|
+
//
|
37
|
+
// The classical address-casting method is:
|
38
|
+
//
|
39
|
+
// // WRONG
|
40
|
+
// float f = 3.14159265358979; // WRONG
|
41
|
+
// int i = * reinterpret_cast<int*>(&f); // WRONG
|
42
|
+
//
|
43
|
+
// The address-casting method actually produces undefined behavior
|
44
|
+
// according to ISO C++ specification section 3.10 -15 -. Roughly, this
|
45
|
+
// section says: if an object in memory has one type, and a program
|
46
|
+
// accesses it with a different type, then the result is undefined
|
47
|
+
// behavior for most values of "different type".
|
48
|
+
//
|
49
|
+
// This is true for any cast syntax, either *(int*)&f or
|
50
|
+
// *reinterpret_cast<int*>(&f). And it is particularly true for
|
51
|
+
// conversions between integral lvalues and floating-point lvalues.
|
52
|
+
//
|
53
|
+
// The purpose of 3.10 -15- is to allow optimizing compilers to assume
|
54
|
+
// that expressions with different types refer to different memory. gcc
|
55
|
+
// 4.0.1 has an optimizer that takes advantage of this. So a
|
56
|
+
// non-conforming program quietly produces wildly incorrect output.
|
57
|
+
//
|
58
|
+
// The problem is not the use of reinterpret_cast. The problem is type
|
59
|
+
// punning: holding an object in memory of one type and reading its bits
|
60
|
+
// back using a different type.
|
61
|
+
//
|
62
|
+
// The C++ standard is more subtle and complex than this, but that
|
63
|
+
// is the basic idea.
|
64
|
+
//
|
65
|
+
// Anyways ...
|
66
|
+
//
|
67
|
+
// lang_id_bit_cast<> calls memcpy() which is blessed by the standard,
|
68
|
+
// especially by the example in section 3.9 . Also, of course,
|
69
|
+
// lang_id_bit_cast<> wraps up the nasty logic in one place.
|
70
|
+
//
|
71
|
+
// Fortunately memcpy() is very fast. In optimized mode, with a
|
72
|
+
// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
|
73
|
+
// code with the minimal amount of data movement. On a 32-bit system,
|
74
|
+
// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
|
75
|
+
// compiles to two loads and two stores.
|
76
|
+
//
|
77
|
+
// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
|
78
|
+
//
|
79
|
+
// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
|
80
|
+
// is likely to surprise you.
|
81
|
+
//
|
82
|
+
// Props to Bill Gibbons for the compile time assertion technique and
|
83
|
+
// Art Komninos and Igor Tandetnik for the msvc experiments.
|
84
|
+
//
|
85
|
+
// -- mec 2005-10-17
|
86
|
+
|
87
|
+
template <class Dest, class Source>
inline Dest lang_id_bit_cast(const Source &source) {
  // Reinterpreting bits is only meaningful between equally sized types.
  static_assert(sizeof(Source) == sizeof(Dest), "Sizes do not match");

  // Copy the object representation byte for byte into a fresh Dest; memcpy is
  // used instead of a pointer cast.
  Dest result;
  memcpy(&result, &source, sizeof(Dest));
  return result;
}
|
95
|
+
|
96
|
+
} // namespace chrome_lang_id
|
97
|
+
|
98
|
+
#endif // CASTS_H_
|
@@ -0,0 +1,51 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#include "embedding_feature_extractor.h"
|
17
|
+
|
18
|
+
#include <stddef.h>
|
19
|
+
#include <vector>
|
20
|
+
|
21
|
+
#include "feature_extractor.h"
|
22
|
+
#include "feature_types.h"
|
23
|
+
#include "task_context.h"
|
24
|
+
#include "utils.h"
|
25
|
+
|
26
|
+
namespace chrome_lang_id {
|
27
|
+
|
28
|
+
// Construction and destruction need no custom logic; both are defined here,
// out of line, with compiler-generated bodies.
GenericEmbeddingFeatureExtractor::GenericEmbeddingFeatureExtractor() = default;

GenericEmbeddingFeatureExtractor::~GenericEmbeddingFeatureExtractor() = default;
|
31
|
+
|
32
|
+
// Reads the extractor configuration from |context| and stores it in the
// members: the FML strings, the embedding names, the embedding dimensions and
// the add_varlen_strings flag. List-valued parameters are ';'-separated.
void GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
  // The FML spec is always read from "<prefix>_features"; the version is not
  // consulted.
  const string fml_spec = context->Get(ArgPrefix() + "_features", "");
  const string names_spec = context->Get(GetParamName("embedding_names"), "");
  const string dims_spec = context->Get(GetParamName("embedding_dims"), "");

  embedding_fml_ = utils::Split(fml_spec, ';');
  add_strings_ = context->Get(GetParamName("add_varlen_strings"), false);
  embedding_names_ = utils::Split(names_spec, ';');

  // Every dimension token is converted with utils::ParseInt32 and appended to
  // the existing list.
  const auto dim_tokens = utils::Split(dims_spec, ';');
  for (const string &token : dim_tokens) {
    const int dim = utils::ParseUsing<int>(token, utils::ParseInt32);
    embedding_dims_.push_back(dim);
  }
}
|
48
|
+
|
49
|
+
void GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {}
|
50
|
+
|
51
|
+
} // namespace chrome_lang_id
|
@@ -0,0 +1,182 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#ifndef EMBEDDING_FEATURE_EXTRACTOR_H_
|
17
|
+
#define EMBEDDING_FEATURE_EXTRACTOR_H_
|
18
|
+
|
19
|
+
#include <memory>
|
20
|
+
#include <string>
|
21
|
+
#include <vector>
|
22
|
+
|
23
|
+
#include "feature_extractor.h"
|
24
|
+
#include "task_context.h"
|
25
|
+
#include "workspace.h"
|
26
|
+
|
27
|
+
namespace chrome_lang_id {
|
28
|
+
|
29
|
+
// An EmbeddingFeatureExtractor manages the extraction of features for
|
30
|
+
// embedding-based models. It wraps a sequence of underlying classes of feature
|
31
|
+
// extractors, along with associated predicate maps. Each class of feature
|
32
|
+
// extractors is associated with a name, e.g., "unigrams", "bigrams".
|
33
|
+
//
|
34
|
+
// The class is split between a generic abstract version,
|
35
|
+
// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
|
36
|
+
// signature of the ExtractFeatures method) and a typed version.
|
37
|
+
//
|
38
|
+
// The predicate maps must be initialized before use: they can be loaded using
|
39
|
+
// Read() or updated via UpdateMapsForExample.
|
40
|
+
class GenericEmbeddingFeatureExtractor {
|
41
|
+
public:
|
42
|
+
GenericEmbeddingFeatureExtractor();
|
43
|
+
virtual ~GenericEmbeddingFeatureExtractor();
|
44
|
+
|
45
|
+
// Get the prefix string to put in front of all arguments, so they don't
|
46
|
+
// conflict with other embedding models.
|
47
|
+
virtual const string ArgPrefix() const = 0;
|
48
|
+
|
49
|
+
// Sets up predicate maps and embedding space names that are common for all
|
50
|
+
// embedding based feature extractors.
|
51
|
+
virtual void Setup(TaskContext *context);
|
52
|
+
virtual void Init(TaskContext *context);
|
53
|
+
|
54
|
+
// Requests workspace for the underlying feature extractors. This is
|
55
|
+
// implemented in the typed class.
|
56
|
+
virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
|
57
|
+
|
58
|
+
// Number of predicates for the embedding at a given index (vocabulary size.)
|
59
|
+
int EmbeddingSize(int index) const {
|
60
|
+
return generic_feature_extractor(index).GetDomainSize();
|
61
|
+
}
|
62
|
+
|
63
|
+
// Returns number of embedding spaces.
|
64
|
+
int NumEmbeddings() const { return embedding_dims_.size(); }
|
65
|
+
|
66
|
+
// Returns the number of features in the embedding space.
|
67
|
+
int FeatureSize(int idx) const {
|
68
|
+
return generic_feature_extractor(idx).feature_types();
|
69
|
+
}
|
70
|
+
|
71
|
+
// Returns the dimensionality of the embedding space.
|
72
|
+
int EmbeddingDims(int index) const { return embedding_dims_[index]; }
|
73
|
+
|
74
|
+
// Accessor for embedding dims (dimensions of the embedding spaces).
|
75
|
+
const std::vector<int> &embedding_dims() const { return embedding_dims_; }
|
76
|
+
|
77
|
+
const std::vector<string> &embedding_fml() const { return embedding_fml_; }
|
78
|
+
|
79
|
+
// Get parameter name by concatenating the prefix and the original name.
|
80
|
+
string GetParamName(const string ¶m_name) const {
|
81
|
+
string name = ArgPrefix();
|
82
|
+
name += "_";
|
83
|
+
name += param_name;
|
84
|
+
return name;
|
85
|
+
}
|
86
|
+
|
87
|
+
protected:
|
88
|
+
// Provides the generic class with access to the templated extractors. This is
|
89
|
+
// used to get the type information out of the feature extractor without
|
90
|
+
// knowing the specific calling arguments of the extractor itself.
|
91
|
+
virtual const GenericFeatureExtractor &generic_feature_extractor(
|
92
|
+
int idx) const = 0;
|
93
|
+
|
94
|
+
private:
|
95
|
+
// Embedding space names for parameter sharing.
|
96
|
+
std::vector<string> embedding_names_;
|
97
|
+
|
98
|
+
// FML strings for each feature extractor.
|
99
|
+
std::vector<string> embedding_fml_;
|
100
|
+
|
101
|
+
// Size of each of the embedding spaces (maximum predicate id).
|
102
|
+
std::vector<int> embedding_sizes_;
|
103
|
+
|
104
|
+
// Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
|
105
|
+
std::vector<int> embedding_dims_;
|
106
|
+
|
107
|
+
// Whether or not to add string descriptions to converted examples.
|
108
|
+
bool add_strings_;
|
109
|
+
};
|
110
|
+
|
111
|
+
// Templated, object-specific implementation of the
|
112
|
+
// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
|
113
|
+
// ARGS...> class that has the appropriate FeatureTraits() to ensure that
|
114
|
+
// locator type features work.
|
115
|
+
//
|
116
|
+
// Note: for backwards compatibility purposes, this always reads the FML spec
|
117
|
+
// from "<prefix>_features".
|
118
|
+
template <class EXTRACTOR, class OBJ, class... ARGS>
|
119
|
+
class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
|
120
|
+
public:
|
121
|
+
// Sets up all predicate maps, feature extractors, and flags.
|
122
|
+
void Setup(TaskContext *context) override {
|
123
|
+
GenericEmbeddingFeatureExtractor::Setup(context);
|
124
|
+
feature_extractors_.resize(embedding_fml().size());
|
125
|
+
for (size_t i = 0; i < embedding_fml().size(); ++i) {
|
126
|
+
feature_extractors_[i].Parse(embedding_fml()[i]);
|
127
|
+
feature_extractors_[i].Setup(context);
|
128
|
+
}
|
129
|
+
}
|
130
|
+
|
131
|
+
// Initializes resources needed by the feature extractors.
|
132
|
+
void Init(TaskContext *context) override {
|
133
|
+
GenericEmbeddingFeatureExtractor::Init(context);
|
134
|
+
for (auto &feature_extractor : feature_extractors_) {
|
135
|
+
feature_extractor.Init(context);
|
136
|
+
}
|
137
|
+
}
|
138
|
+
|
139
|
+
// Requests workspaces from the registry. Must be called after Init(), and
|
140
|
+
// before Preprocess().
|
141
|
+
void RequestWorkspaces(WorkspaceRegistry *registry) override {
|
142
|
+
for (auto &feature_extractor : feature_extractors_) {
|
143
|
+
feature_extractor.RequestWorkspaces(registry);
|
144
|
+
}
|
145
|
+
}
|
146
|
+
|
147
|
+
// Must be called on the object one state for each sentence, before any
|
148
|
+
// feature extraction (e.g., UpdateMapsForExample, ExtractSparseFeatures).
|
149
|
+
void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
|
150
|
+
for (auto &feature_extractor : feature_extractors_) {
|
151
|
+
feature_extractor.Preprocess(workspaces, obj);
|
152
|
+
}
|
153
|
+
}
|
154
|
+
|
155
|
+
// Extracts features using the extractors. Note that features must already
|
156
|
+
// be initialized to the correct number of feature extractors. No predicate
|
157
|
+
// mapping is applied.
|
158
|
+
void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
|
159
|
+
ARGS... args,
|
160
|
+
std::vector<FeatureVector> *features) const {
|
161
|
+
for (size_t i = 0; i < feature_extractors_.size(); ++i) {
|
162
|
+
features->at(i).clear();
|
163
|
+
feature_extractors_.at(i).ExtractFeatures(workspaces, obj, args...,
|
164
|
+
&features->at(i));
|
165
|
+
}
|
166
|
+
}
|
167
|
+
|
168
|
+
protected:
|
169
|
+
// Provides generic access to the feature extractors.
|
170
|
+
const GenericFeatureExtractor &generic_feature_extractor(
|
171
|
+
int idx) const override {
|
172
|
+
return feature_extractors_.at(idx);
|
173
|
+
}
|
174
|
+
|
175
|
+
private:
|
176
|
+
// Templated feature extractor class.
|
177
|
+
std::vector<EXTRACTOR> feature_extractors_;
|
178
|
+
};
|
179
|
+
|
180
|
+
} // namespace chrome_lang_id
|
181
|
+
|
182
|
+
#endif // EMBEDDING_FEATURE_EXTRACTOR_H_
|