cld3 3.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Gemfile +18 -0
- data/LICENSE +204 -0
- data/LICENSE_CLD3 +203 -0
- data/README.md +22 -0
- data/cld3.gemspec +35 -0
- data/ext/cld3/base.cc +36 -0
- data/ext/cld3/base.h +106 -0
- data/ext/cld3/casts.h +98 -0
- data/ext/cld3/embedding_feature_extractor.cc +51 -0
- data/ext/cld3/embedding_feature_extractor.h +182 -0
- data/ext/cld3/embedding_network.cc +196 -0
- data/ext/cld3/embedding_network.h +186 -0
- data/ext/cld3/embedding_network_params.h +285 -0
- data/ext/cld3/extconf.rb +49 -0
- data/ext/cld3/feature_extractor.cc +137 -0
- data/ext/cld3/feature_extractor.h +633 -0
- data/ext/cld3/feature_extractor.proto +50 -0
- data/ext/cld3/feature_types.cc +72 -0
- data/ext/cld3/feature_types.h +158 -0
- data/ext/cld3/fixunicodevalue.cc +55 -0
- data/ext/cld3/fixunicodevalue.h +69 -0
- data/ext/cld3/float16.h +58 -0
- data/ext/cld3/fml_parser.cc +308 -0
- data/ext/cld3/fml_parser.h +123 -0
- data/ext/cld3/generated_entities.cc +296 -0
- data/ext/cld3/generated_ulscript.cc +678 -0
- data/ext/cld3/generated_ulscript.h +142 -0
- data/ext/cld3/getonescriptspan.cc +1109 -0
- data/ext/cld3/getonescriptspan.h +124 -0
- data/ext/cld3/integral_types.h +37 -0
- data/ext/cld3/lang_id_nn_params.cc +57449 -0
- data/ext/cld3/lang_id_nn_params.h +178 -0
- data/ext/cld3/language_identifier_features.cc +165 -0
- data/ext/cld3/language_identifier_features.h +116 -0
- data/ext/cld3/nnet_language_identifier.cc +380 -0
- data/ext/cld3/nnet_language_identifier.h +175 -0
- data/ext/cld3/nnet_language_identifier_c.cc +72 -0
- data/ext/cld3/offsetmap.cc +478 -0
- data/ext/cld3/offsetmap.h +168 -0
- data/ext/cld3/port.h +143 -0
- data/ext/cld3/registry.cc +28 -0
- data/ext/cld3/registry.h +242 -0
- data/ext/cld3/relevant_script_feature.cc +89 -0
- data/ext/cld3/relevant_script_feature.h +49 -0
- data/ext/cld3/script_detector.h +156 -0
- data/ext/cld3/sentence.proto +77 -0
- data/ext/cld3/sentence_features.cc +29 -0
- data/ext/cld3/sentence_features.h +35 -0
- data/ext/cld3/simple_adder.h +72 -0
- data/ext/cld3/stringpiece.h +81 -0
- data/ext/cld3/task_context.cc +161 -0
- data/ext/cld3/task_context.h +81 -0
- data/ext/cld3/task_context_params.cc +74 -0
- data/ext/cld3/task_context_params.h +54 -0
- data/ext/cld3/task_spec.proto +98 -0
- data/ext/cld3/text_processing.cc +245 -0
- data/ext/cld3/text_processing.h +30 -0
- data/ext/cld3/unicodetext.cc +96 -0
- data/ext/cld3/unicodetext.h +144 -0
- data/ext/cld3/utf8acceptinterchange.h +486 -0
- data/ext/cld3/utf8prop_lettermarkscriptnum.h +1631 -0
- data/ext/cld3/utf8repl_lettermarklower.h +758 -0
- data/ext/cld3/utf8scannot_lettermarkspecial.h +1455 -0
- data/ext/cld3/utf8statetable.cc +1344 -0
- data/ext/cld3/utf8statetable.h +285 -0
- data/ext/cld3/utils.cc +241 -0
- data/ext/cld3/utils.h +144 -0
- data/ext/cld3/workspace.cc +64 -0
- data/ext/cld3/workspace.h +177 -0
- data/lib/cld3.rb +99 -0
- metadata +158 -0
data/cld3.gemspec
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# Copyright 2017 Akihiko Odaki <akihiko.odaki.4i@stu.hosei.ac.jp>
|
2
|
+
# All Rights Reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
#==============================================================================
|
16
|
+
|
17
|
+
# Gem specification for the cld3 Ruby binding. The native sources under ext/
# are compiled at install time through ext/cld3/extconf.rb.
Gem::Specification.new do |spec|
  # Identity and descriptive metadata.
  spec.name        = "cld3"
  spec.version     = "3.1.0"
  spec.summary     = "Compact Language Detector v3 (CLD3)"
  spec.description = "Compact Language Detector v3 (CLD3) is a neural network model for language identification."
  spec.license     = "Apache-2.0"
  spec.homepage    = "https://github.com/akihikodaki/cld3-ruby"
  spec.author      = "Akihiko Odaki"
  spec.email       = "akihiko.odaki.4i@stu.hosei.ac.jp"

  # Supported interpreter range and dependencies.
  spec.required_ruby_version = [ ">= 2.3.0", "< 2.5.0" ]
  spec.add_dependency "ffi", [ ">= 1.1.0", "< 1.10.0" ]
  spec.add_development_dependency "rspec", [ ">=2.11.0", "< 3.7.0" ]

  # Files shipped in the gem: top-level documents plus everything under
  # ext/ and lib/.
  packaged_patterns = [
    "Gemfile", "LICENSE", "LICENSE_CLD3", "README.md",
    "cld3.gemspec", "ext/**/*", "lib/**/*"
  ]
  spec.files = Dir[*packaged_patterns]

  spec.require_paths = [ "lib" ]
  spec.extensions    = [ "ext/cld3/extconf.rb" ]
end
|
data/ext/cld3/base.cc
ADDED
@@ -0,0 +1,36 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#include "base.h"
|
17
|
+
|
18
|
+
#include <string>
|
19
|
+
#ifdef COMPILER_MSVC
|
20
|
+
#include <sstream>
|
21
|
+
#endif // COMPILER_MSVC
|
22
|
+
|
23
|
+
namespace chrome_lang_id {
|
24
|
+
|
25
|
+
// TODO(abakalov): Pick the most efficient approach.
|
26
|
+
#ifdef COMPILER_MSVC
|
27
|
+
std::string Int64ToString(int64 input) {
|
28
|
+
std::stringstream stream;
|
29
|
+
stream << input;
|
30
|
+
return stream.str();
|
31
|
+
}
|
32
|
+
#else
|
33
|
+
std::string Int64ToString(int64 input) { return std::to_string(input); }
|
34
|
+
#endif // COMPILER_MSVC
|
35
|
+
|
36
|
+
} // namespace chrome_lang_id
|
data/ext/cld3/base.h
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#ifndef BASE_H_
|
17
|
+
#define BASE_H_
|
18
|
+
|
19
|
+
#include <cassert>
|
20
|
+
#include <map>
|
21
|
+
#include <string>
|
22
|
+
#include <vector>
|
23
|
+
|
24
|
+
namespace chrome_lang_id {
|
25
|
+
|
26
|
+
// Standard-library names used unqualified throughout this namespace.
using std::map;
using std::pair;
using std::string;
using std::vector;

// Unsigned integer alias used by the library.
typedef unsigned int uint32;
|
31
|
+
|
32
|
+
// CLD3_DISALLOW_COPY_AND_ASSIGN(TypeName) declares the copy constructor and
// copy-assignment operator of TypeName so that instances cannot be copied.
// Place it inside the class body (conventionally in the private section) and
// terminate it with a semicolon.
//
// The deleted-function form is used whenever the compiler is known to support
// C++11: either the build defines LANG_CXX11 to a non-zero value, or the
// compiler reports a C++11 language level via __cplusplus, or it is
// MSVC 2013 (_MSC_VER 1800) or newer, which supports "= delete" even though
// its __cplusplus value may still report C++98.
#if LANG_CXX11 || (defined(__cplusplus) && __cplusplus >= 201103L) || \
    (defined(_MSC_VER) && _MSC_VER >= 1800)
#define CLD3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName &) = delete;          \
  TypeName &operator=(const TypeName &) = delete
#else  // C++98 case follows

// Note that these C++98 implementations cannot completely disallow copying,
// as members and friends can still accidentally make elided copies without
// triggering a linker error.
#define CLD3_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName &);                   \
  TypeName &operator=(const TypeName &)
#endif  // LANG_CXX11
|
45
|
+
|
46
|
+
// CLD3_IMMEDIATE_CRASH() terminates the program on the spot. A definition
// supplied before this header is included takes precedence.
#if !defined(CLD3_IMMEDIATE_CRASH)
#if !defined(__GNUC__) && !defined(__clang__)
// No trap builtin available: force a fault by writing through a null pointer.
#define CLD3_IMMEDIATE_CRASH() ((void)(*(volatile char *)0 = 0))
#else
#define CLD3_IMMEDIATE_CRASH() __builtin_trap()
#endif
#endif  // CLD3_IMMEDIATE_CRASH

// CLD3_CHECK(f) evaluates |f| exactly once and crashes when it is false.
#define CLD3_CHECK(f) ((f) ? (void)0 : CLD3_IMMEDIATE_CRASH())

// CLD3_DCHECK(f) behaves like CLD3_CHECK(f), except that it expands to a
// no-op (|f| is not evaluated) when NDEBUG is defined and DCHECK_ALWAYS_ON
// is not.
#if !defined(NDEBUG) || defined(DCHECK_ALWAYS_ON)
#define CLD3_DCHECK(f) CLD3_CHECK(f)
#else
#define CLD3_DCHECK(f) ((void)0)
#endif
|
61
|
+
|
62
|
+
// Fixed-name integer aliases. The first group is hidden from SWIG.
#ifndef SWIG
typedef unsigned char uint8;    // NOLINT
typedef unsigned short uint16;  // NOLINT
typedef int int32;

// char32 holds a Unicode code-point value; as of Unicode 4.0 such values
// need up to 21 bits. It is deliberately spelled as an explicitly signed type
// (for type-checking on pointers) and should always be the signed version of
// whatever int32 is.
typedef signed int char32;
#endif  // SWIG

// 64-bit signed integer; the underlying spelling depends on the compiler.
#ifndef COMPILER_MSVC
typedef long long int64;  // NOLINT
#else
typedef __int64 int64;
#endif  // COMPILER_MSVC
|
79
|
+
|
80
|
+
// CLD3_ATTRIBUTE_ALWAYS_INLINE marks functions we want to force inline.
#if defined(__GNUC__) && (__GNUC__ * 100 + __GNUC_MINOR__ >= 301)

// GCC-compatible compilers: the attribute was introduced in gcc 3.1.
#define CLD3_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline))

#elif defined(_MSC_VER)
#define CLD3_ATTRIBUTE_ALWAYS_INLINE __forceinline
#else

// Other compilers will have to figure it out for themselves; the macro
// expands to nothing.
#define CLD3_ATTRIBUTE_ALWAYS_INLINE
#endif
|
94
|
+
|
95
|
+
#ifdef INTERNAL_BUILD
|
96
|
+
typedef basic_string<char> bstring;
|
97
|
+
#else
|
98
|
+
typedef std::basic_string<char> bstring;
|
99
|
+
#endif // INTERNAL_BUILD
|
100
|
+
|
101
|
+
// Converts int64 to string.
|
102
|
+
std::string Int64ToString(int64 input);
|
103
|
+
|
104
|
+
} // namespace chrome_lang_id
|
105
|
+
|
106
|
+
#endif // BASE_H_
|
data/ext/cld3/casts.h
ADDED
@@ -0,0 +1,98 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
// This code is compiled directly on many platforms, including client
|
17
|
+
// platforms like Windows, Mac, and embedded systems. Before making
|
18
|
+
// any changes here, make sure that you're not breaking any platforms.
|
19
|
+
//
|
20
|
+
|
21
|
+
#ifndef CASTS_H_
|
22
|
+
#define CASTS_H_
|
23
|
+
|
24
|
+
#include <string.h> // for memcpy
|
25
|
+
|
26
|
+
namespace chrome_lang_id {
|
27
|
+
|
28
|
+
// lang_id_bit_cast<Dest,Source> is a template function that implements the
|
29
|
+
// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
|
30
|
+
// very low-level functions like the protobuf library and fast math
|
31
|
+
// support.
|
32
|
+
//
|
33
|
+
// float f = 3.14159265358979;
|
34
|
+
// int i = lang_id_bit_cast<int32>(f);
|
35
|
+
// // i = 0x40490fdb
|
36
|
+
//
|
37
|
+
// The classical address-casting method is:
|
38
|
+
//
|
39
|
+
// // WRONG
|
40
|
+
// float f = 3.14159265358979; // WRONG
|
41
|
+
// int i = * reinterpret_cast<int*>(&f); // WRONG
|
42
|
+
//
|
43
|
+
// The address-casting method actually produces undefined behavior
|
44
|
+
// according to ISO C++ specification section 3.10 -15 -. Roughly, this
|
45
|
+
// section says: if an object in memory has one type, and a program
|
46
|
+
// accesses it with a different type, then the result is undefined
|
47
|
+
// behavior for most values of "different type".
|
48
|
+
//
|
49
|
+
// This is true for any cast syntax, either *(int*)&f or
|
50
|
+
// *reinterpret_cast<int*>(&f). And it is particularly true for
|
51
|
+
// conversions between integral lvalues and floating-point lvalues.
|
52
|
+
//
|
53
|
+
// The purpose of 3.10 -15- is to allow optimizing compilers to assume
|
54
|
+
// that expressions with different types refer to different memory. gcc
|
55
|
+
// 4.0.1 has an optimizer that takes advantage of this. So a
|
56
|
+
// non-conforming program quietly produces wildly incorrect output.
|
57
|
+
//
|
58
|
+
// The problem is not the use of reinterpret_cast. The problem is type
|
59
|
+
// punning: holding an object in memory of one type and reading its bits
|
60
|
+
// back using a different type.
|
61
|
+
//
|
62
|
+
// The C++ standard is more subtle and complex than this, but that
|
63
|
+
// is the basic idea.
|
64
|
+
//
|
65
|
+
// Anyways ...
|
66
|
+
//
|
67
|
+
// lang_id_bit_cast<> calls memcpy() which is blessed by the standard,
|
68
|
+
// especially by the example in section 3.9 . Also, of course,
|
69
|
+
// lang_id_bit_cast<> wraps up the nasty logic in one place.
|
70
|
+
//
|
71
|
+
// Fortunately memcpy() is very fast. In optimized mode, with a
|
72
|
+
// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
|
73
|
+
// code with the minimal amount of data movement. On a 32-bit system,
|
74
|
+
// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
|
75
|
+
// compiles to two loads and two stores.
|
76
|
+
//
|
77
|
+
// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
|
78
|
+
//
|
79
|
+
// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
|
80
|
+
// is likely to surprise you.
|
81
|
+
//
|
82
|
+
// Props to Bill Gibbons for the compile time assertion technique and
|
83
|
+
// Art Komninos and Igor Tandetnik for the msvc experiments.
|
84
|
+
//
|
85
|
+
// -- mec 2005-10-17
|
86
|
+
|
87
|
+
template <class Dest, class Source>
inline Dest lang_id_bit_cast(const Source &source) {
  // Reinterpreting the bytes only makes sense between equally sized types.
  static_assert(sizeof(Source) == sizeof(Dest), "Sizes do not match");

  // Copy the object representation byte-for-byte instead of casting
  // addresses.
  Dest reinterpreted;
  memcpy(&reinterpreted, &source, sizeof(Dest));
  return reinterpreted;
}
|
95
|
+
|
96
|
+
} // namespace chrome_lang_id
|
97
|
+
|
98
|
+
#endif // CASTS_H_
|
data/ext/cld3/embedding_feature_extractor.cc
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#include "embedding_feature_extractor.h"
|
17
|
+
|
18
|
+
#include <stddef.h>
|
19
|
+
#include <vector>
|
20
|
+
|
21
|
+
#include "feature_extractor.h"
|
22
|
+
#include "feature_types.h"
|
23
|
+
#include "task_context.h"
|
24
|
+
#include "utils.h"
|
25
|
+
|
26
|
+
namespace chrome_lang_id {
|
27
|
+
|
28
|
+
// Default construction: no work beyond default-initializing the members.
GenericEmbeddingFeatureExtractor::GenericEmbeddingFeatureExtractor() = default;
|
29
|
+
|
30
|
+
// Out-of-line virtual destructor; nothing to release beyond the members.
GenericEmbeddingFeatureExtractor::~GenericEmbeddingFeatureExtractor() = default;
|
31
|
+
|
32
|
+
// Reads the embedding configuration for this extractor from |context|.
//
// All parameter names are formed as "<ArgPrefix()>_<name>". The values read
// are:
//   <prefix>_features            ';'-separated FML specs
//   <prefix>_embedding_names     ';'-separated embedding space names
//   <prefix>_embedding_dims      ';'-separated integer dimensions
//   <prefix>_add_varlen_strings  boolean flag (defaults to false)
//
// Every piece of state written here is replaced rather than appended to, so
// calling Setup() again does not leave stale dimensions behind.
void GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
  // Don't use version to determine how to get feature FML.
  string features_param = ArgPrefix();
  features_param += "_features";
  const string features = context->Get(features_param, "");
  const string embedding_names =
      context->Get(GetParamName("embedding_names"), "");
  const string embedding_dims =
      context->Get(GetParamName("embedding_dims"), "");
  embedding_fml_ = utils::Split(features, ';');
  add_strings_ = context->Get(GetParamName("add_varlen_strings"), false);
  embedding_names_ = utils::Split(embedding_names, ';');

  // Start from an empty list so that embedding_dims_ stays consistent with
  // embedding_fml_ / embedding_names_, which are reassigned above.
  embedding_dims_.clear();
  for (const string &dim : utils::Split(embedding_dims, ';')) {
    embedding_dims_.push_back(utils::ParseUsing<int>(dim, utils::ParseInt32));
  }
}
|
48
|
+
|
49
|
+
// The generic base has nothing to initialize; the typed subclass overrides
// this to initialize its feature extractors.
void GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
  // Intentionally empty.
}
|
50
|
+
|
51
|
+
} // namespace chrome_lang_id
|
data/ext/cld3/embedding_feature_extractor.h
ADDED
@@ -0,0 +1,182 @@
|
|
1
|
+
/* Copyright 2016 Google Inc. All Rights Reserved.
|
2
|
+
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
==============================================================================*/
|
15
|
+
|
16
|
+
#ifndef EMBEDDING_FEATURE_EXTRACTOR_H_
|
17
|
+
#define EMBEDDING_FEATURE_EXTRACTOR_H_
|
18
|
+
|
19
|
+
#include <memory>
|
20
|
+
#include <string>
|
21
|
+
#include <vector>
|
22
|
+
|
23
|
+
#include "feature_extractor.h"
|
24
|
+
#include "task_context.h"
|
25
|
+
#include "workspace.h"
|
26
|
+
|
27
|
+
namespace chrome_lang_id {
|
28
|
+
|
29
|
+
// An EmbeddingFeatureExtractor manages the extraction of features for
|
30
|
+
// embedding-based models. It wraps a sequence of underlying classes of feature
|
31
|
+
// extractors, along with associated predicate maps. Each class of feature
|
32
|
+
// extractors is associated with a name, e.g., "unigrams", "bigrams".
|
33
|
+
//
|
34
|
+
// The class is split between a generic abstract version,
|
35
|
+
// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
|
36
|
+
// signature of the ExtractFeatures method) and a typed version.
|
37
|
+
//
|
38
|
+
// The predicate maps must be initialized before use: they can be loaded using
|
39
|
+
// Read() or updated via UpdateMapsForExample.
|
40
|
+
class GenericEmbeddingFeatureExtractor {
|
41
|
+
public:
|
42
|
+
GenericEmbeddingFeatureExtractor();
|
43
|
+
virtual ~GenericEmbeddingFeatureExtractor();
|
44
|
+
|
45
|
+
// Get the prefix string to put in front of all arguments, so they don't
|
46
|
+
// conflict with other embedding models.
|
47
|
+
virtual const string ArgPrefix() const = 0;
|
48
|
+
|
49
|
+
// Sets up predicate maps and embedding space names that are common for all
|
50
|
+
// embedding based feature extractors.
|
51
|
+
virtual void Setup(TaskContext *context);
|
52
|
+
virtual void Init(TaskContext *context);
|
53
|
+
|
54
|
+
// Requests workspace for the underlying feature extractors. This is
|
55
|
+
// implemented in the typed class.
|
56
|
+
virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
|
57
|
+
|
58
|
+
// Number of predicates for the embedding at a given index (vocabulary size.)
|
59
|
+
int EmbeddingSize(int index) const {
|
60
|
+
return generic_feature_extractor(index).GetDomainSize();
|
61
|
+
}
|
62
|
+
|
63
|
+
// Returns number of embedding spaces.
|
64
|
+
int NumEmbeddings() const { return embedding_dims_.size(); }
|
65
|
+
|
66
|
+
// Returns the number of features in the embedding space.
|
67
|
+
int FeatureSize(int idx) const {
|
68
|
+
return generic_feature_extractor(idx).feature_types();
|
69
|
+
}
|
70
|
+
|
71
|
+
// Returns the dimensionality of the embedding space.
|
72
|
+
int EmbeddingDims(int index) const { return embedding_dims_[index]; }
|
73
|
+
|
74
|
+
// Accessor for embedding dims (dimensions of the embedding spaces).
|
75
|
+
const std::vector<int> &embedding_dims() const { return embedding_dims_; }
|
76
|
+
|
77
|
+
const std::vector<string> &embedding_fml() const { return embedding_fml_; }
|
78
|
+
|
79
|
+
// Get parameter name by concatenating the prefix and the original name.
|
80
|
+
string GetParamName(const string ¶m_name) const {
|
81
|
+
string name = ArgPrefix();
|
82
|
+
name += "_";
|
83
|
+
name += param_name;
|
84
|
+
return name;
|
85
|
+
}
|
86
|
+
|
87
|
+
protected:
|
88
|
+
// Provides the generic class with access to the templated extractors. This is
|
89
|
+
// used to get the type information out of the feature extractor without
|
90
|
+
// knowing the specific calling arguments of the extractor itself.
|
91
|
+
virtual const GenericFeatureExtractor &generic_feature_extractor(
|
92
|
+
int idx) const = 0;
|
93
|
+
|
94
|
+
private:
|
95
|
+
// Embedding space names for parameter sharing.
|
96
|
+
std::vector<string> embedding_names_;
|
97
|
+
|
98
|
+
// FML strings for each feature extractor.
|
99
|
+
std::vector<string> embedding_fml_;
|
100
|
+
|
101
|
+
// Size of each of the embedding spaces (maximum predicate id).
|
102
|
+
std::vector<int> embedding_sizes_;
|
103
|
+
|
104
|
+
// Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
|
105
|
+
std::vector<int> embedding_dims_;
|
106
|
+
|
107
|
+
// Whether or not to add string descriptions to converted examples.
|
108
|
+
bool add_strings_;
|
109
|
+
};
|
110
|
+
|
111
|
+
// Templated, object-specific implementation of the
|
112
|
+
// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
|
113
|
+
// ARGS...> class that has the appropriate FeatureTraits() to ensure that
|
114
|
+
// locator type features work.
|
115
|
+
//
|
116
|
+
// Note: for backwards compatibility purposes, this always reads the FML spec
|
117
|
+
// from "<prefix>_features".
|
118
|
+
template <class EXTRACTOR, class OBJ, class... ARGS>
|
119
|
+
class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
|
120
|
+
public:
|
121
|
+
// Sets up all predicate maps, feature extractors, and flags.
|
122
|
+
void Setup(TaskContext *context) override {
|
123
|
+
GenericEmbeddingFeatureExtractor::Setup(context);
|
124
|
+
feature_extractors_.resize(embedding_fml().size());
|
125
|
+
for (size_t i = 0; i < embedding_fml().size(); ++i) {
|
126
|
+
feature_extractors_[i].Parse(embedding_fml()[i]);
|
127
|
+
feature_extractors_[i].Setup(context);
|
128
|
+
}
|
129
|
+
}
|
130
|
+
|
131
|
+
// Initializes resources needed by the feature extractors.
|
132
|
+
void Init(TaskContext *context) override {
|
133
|
+
GenericEmbeddingFeatureExtractor::Init(context);
|
134
|
+
for (auto &feature_extractor : feature_extractors_) {
|
135
|
+
feature_extractor.Init(context);
|
136
|
+
}
|
137
|
+
}
|
138
|
+
|
139
|
+
// Requests workspaces from the registry. Must be called after Init(), and
|
140
|
+
// before Preprocess().
|
141
|
+
void RequestWorkspaces(WorkspaceRegistry *registry) override {
|
142
|
+
for (auto &feature_extractor : feature_extractors_) {
|
143
|
+
feature_extractor.RequestWorkspaces(registry);
|
144
|
+
}
|
145
|
+
}
|
146
|
+
|
147
|
+
// Must be called on the object one state for each sentence, before any
|
148
|
+
// feature extraction (e.g., UpdateMapsForExample, ExtractSparseFeatures).
|
149
|
+
void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
|
150
|
+
for (auto &feature_extractor : feature_extractors_) {
|
151
|
+
feature_extractor.Preprocess(workspaces, obj);
|
152
|
+
}
|
153
|
+
}
|
154
|
+
|
155
|
+
// Extracts features using the extractors. Note that features must already
|
156
|
+
// be initialized to the correct number of feature extractors. No predicate
|
157
|
+
// mapping is applied.
|
158
|
+
void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
|
159
|
+
ARGS... args,
|
160
|
+
std::vector<FeatureVector> *features) const {
|
161
|
+
for (size_t i = 0; i < feature_extractors_.size(); ++i) {
|
162
|
+
features->at(i).clear();
|
163
|
+
feature_extractors_.at(i).ExtractFeatures(workspaces, obj, args...,
|
164
|
+
&features->at(i));
|
165
|
+
}
|
166
|
+
}
|
167
|
+
|
168
|
+
protected:
|
169
|
+
// Provides generic access to the feature extractors.
|
170
|
+
const GenericFeatureExtractor &generic_feature_extractor(
|
171
|
+
int idx) const override {
|
172
|
+
return feature_extractors_.at(idx);
|
173
|
+
}
|
174
|
+
|
175
|
+
private:
|
176
|
+
// Templated feature extractor class.
|
177
|
+
std::vector<EXTRACTOR> feature_extractors_;
|
178
|
+
};
|
179
|
+
|
180
|
+
} // namespace chrome_lang_id
|
181
|
+
|
182
|
+
#endif // EMBEDDING_FEATURE_EXTRACTOR_H_
|