informers 0.2.0 → 1.0.1
This diff shows the changes between publicly released package versions as they appear in their respective registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +70 -95
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +29 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3abc738d8975839b873bc5e07bb95305d455a9ac1eec94c432415b713411f20b
+  data.tar.gz: b9c36794c33316378752dd816fb517714c6d8186062562a778d3c8539ba7d79a
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: ce05bfcdebce333fd6b5abefca703850d3a6d6a50c3c1589bf675e91ae24b424f2e43e6bc0270ad4ea8a520f5be9d636c5e8a5a66deae2c0183adae6cbc517aa
+  data.tar.gz: 6cc9b08b6e0f9e8ea23f306c0c460dc2557e4ee5113ef26300b517608485ea528fcb9254d51f395c37b557bf1728051c2c3dd8a20a25b5bd4826832a4ff30bf8
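The published SHA256 digests above can double as a local integrity check. A minimal sketch, assuming the archive has been unpacked from the downloaded `.gem` (the file path is illustrative; `Digest` is Ruby's standard library):

```ruby
require "digest"

# Expected value from checksums.yaml above
expected = "b9c36794c33316378752dd816fb517714c6d8186062562a778d3c8539ba7d79a"

# Hypothetical path to the data archive extracted from the .gem file
actual = Digest::SHA256.file("informers-1.0.1/data.tar.gz").hexdigest

raise "checksum mismatch" unless actual == expected
```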
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,14 @@
+## 1.0.1 (2024-08-27)
+
+- Added support for `Supabase/gte-small` to `Model`
+- Fixed error with downloads
+
+## 1.0.0 (2024-08-26)
+
+- Replaced task classes with `pipeline` method
+- Added `Model` class
+- Dropped support for Ruby < 3.1
+
 ## 0.2.0 (2022-09-06)
 
 - Added support for `optimum` and `transformers.onnx` models
data/README.md
CHANGED
@@ -1,15 +1,10 @@
 # Informers
 
-:
+:fire: Fast [transformer](https://github.com/xenova/transformers.js) inference for Ruby
 
-
+For non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby) :slightly_smiling_face:
 
-
-- Question answering
-- Named-entity recognition
-- Text generation
-
-[](https://github.com/ankane/informers/actions)
+[](https://github.com/ankane/informers/actions)
 
 ## Installation
 
@@ -21,140 +16,122 @@ gem "informers"
 
 ## Getting Started
 
-- [
-- [
-- [Named-entity recognition](#named-entity-recognition)
-- [Text generation](#text-generation)
-- [Feature extraction](#feature-extraction)
-- [Fill mask](#fill-mask)
+- [Models](#models)
+- [Pipelines](#pipelines)
 
-
+## Models
 
-
+### sentence-transformers/all-MiniLM-L6-v2
 
-
+[Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
 
 ```ruby
-
-model.predict("This is super cool")
-```
+sentences = ["This is an example sentence", "Each sentence is converted"]
 
-
-
-```ruby
-{label: "positive", score: 0.999855186578301}
+model = Informers::Model.new("sentence-transformers/all-MiniLM-L6-v2")
+embeddings = model.embed(sentences)
 ```
 
-
+For a quantized version, use:
 
 ```ruby
-model.
+model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)
 ```
 
-###
+### Xenova/multi-qa-MiniLM-L6-cos-v1
 
-
-
-Ask a question with some context
+[Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)
 
 ```ruby
-
-
-
-
-)
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Informers::Model.new("Xenova/multi-qa-MiniLM-L6-cos-v1")
+query_embedding = model.embed(query)
+doc_embeddings = model.embed(docs)
+scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
+doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
 ```
 
-
+### mixedbread-ai/mxbai-embed-large-v1
+
+[Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
 
 ```ruby
-
-
+def transform_query(query)
+  "Represent this sentence for searching relevant passages: #{query}"
+end
+
+docs = [
+  transform_query("puppy"),
+  "The dog is barking",
+  "The cat is purring"
+]
 
-
+model = Informers::Model.new("mixedbread-ai/mxbai-embed-large-v1")
+embeddings = model.embed(docs)
+```
 
-
+### Supabase/gte-small
 
-
+[Docs](https://huggingface.co/Supabase/gte-small)
 
 ```ruby
-
-model.predict("Nat works at GitHub in San Francisco")
-```
+sentences = ["That is a happy person", "That is a very happy person"]
 
-
-
-```ruby
-[
-  {text: "Nat", tag: "person", score: 0.9840519576513487, start: 0, end: 3},
-  {text: "GitHub", tag: "org", score: 0.9426134775785775, start: 13, end: 19},
-  {text: "San Francisco", tag: "location", score: 0.9952414982243061, start: 23, end: 36}
-]
+model = Informers::Model.new("Supabase/gte-small")
+embeddings = model.embed(sentences)
 ```
 
-
-
-First, export the [pretrained model](tools/export.md).
+## Pipelines
 
-
+Named-entity recognition
 
 ```ruby
-
-
+ner = Informers.pipeline("ner")
+ner.("Ruby is a programming language created by Matz")
 ```
 
-
+Sentiment analysis
 
-```
-
+```ruby
+classifier = Informers.pipeline("sentiment-analysis")
+classifier.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
-
-
-First, export a [pretrained model](tools/export.md).
+Question answering
 
 ```ruby
-
-
+qa = Informers.pipeline("question-answering")
+qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
 ```
 
-
-
-First, export a [pretrained model](tools/export.md).
+Feature extraction
 
 ```ruby
-
-
+extractor = Informers.pipeline("feature-extraction")
+extractor.("We are very happy to show you the 🤗 Transformers library.")
 ```
 
-##
-
-Task | Description | Contributor | License | Link
---- | --- | --- | --- | ---
-Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
-Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
-Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
-Text generation | GPT-2 | OpenAI | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
+## Credits
 
-
+This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.
 
-##
+## Upgrading
 
-
+### 1.0
 
-
-trove push sentiment-analysis.onnx
-```
+Task classes have been replaced with the `pipeline` method.
 
-
-
-
-
-- [Transformers](https://github.com/huggingface/transformers) for transformer models
-- [Bling Fire](https://github.com/microsoft/BlingFire) and [BERT](https://github.com/google-research/bert) for high-performance text tokenization
-- [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high-performance inference
+```ruby
+# before
+model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
+model.predict("This is super cool")
 
-
+# after
+model = Informers.pipeline("sentiment-analysis")
+model.("This is super cool")
+```
 
 ## History
 
@@ -175,7 +152,5 @@ To get started with development:
 git clone https://github.com/ankane/informers.git
 cd informers
 bundle install
-
-export MODELS_PATH=path/to/onnx/models
 bundle exec rake test
 ```
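The multi-qa example in the README above scores documents with a plain dot product. Per `data/lib/informers/model.rb` later in this diff, `embed` uses `pooling: "mean", normalize: true` for that model, so the vectors are unit length and the dot product already equals cosine similarity. For unnormalized embeddings an explicit cosine similarity would be needed; a minimal sketch (the helper name is ours, not part of the gem):

```ruby
# Cosine similarity between two plain-array embeddings.
# For unit-length vectors the denominator is 1.0, so this
# reduces to the dot product used in the README example.
def cosine_similarity(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  dot / (Math.sqrt(a.sum { |x| x * x }) * Math.sqrt(b.sum { |x| x * x }))
end
```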
data/lib/informers/configs.rb
ADDED
@@ -0,0 +1,48 @@
+module Informers
+  class PretrainedConfig
+    attr_reader :model_type, :problem_type, :id2label
+
+    def initialize(config_json)
+      @is_encoder_decoder = false
+
+      @model_type = config_json["model_type"]
+      @problem_type = config_json["problem_type"]
+      @id2label = config_json["id2label"]
+    end
+
+    def [](key)
+      instance_variable_get("@#{key}")
+    end
+
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      **kwargs
+    )
+      data = config || load_config(
+        pretrained_model_name_or_path,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:
+      )
+      new(data)
+    end
+
+    def self.load_config(pretrained_model_name_or_path, **options)
+      info = Utils::Hub.get_model_json(pretrained_model_name_or_path, "config.json", true, **options)
+      info
+    end
+  end
+
+  class AutoConfig
+    def self.from_pretrained(...)
+      PretrainedConfig.from_pretrained(...)
+    end
+  end
+end
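A brief usage sketch for the config classes just added (the model id is illustrative; the attribute values come from the checkpoint's `config.json`, so they depend on the model):

```ruby
# Fetches config.json from the Hub (or local cache) and wraps it
config = Informers::AutoConfig.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

config.model_type   # reader defined via attr_reader
config[:id2label]   # hash-style access via the [] method defined above
```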
data/lib/informers/env.rb
ADDED
@@ -0,0 +1,14 @@
+module Informers
+  CACHE_HOME = ENV.fetch("XDG_CACHE_HOME", File.join(ENV.fetch("HOME"), ".cache"))
+  DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, "informers"))
+
+  class << self
+    attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir
+  end
+
+  self.allow_remote_models = ENV["INFORMERS_OFFLINE"].to_s.empty?
+  self.remote_host = "https://huggingface.co/"
+  self.remote_path_template = "{model}/resolve/{revision}/"
+
+  self.cache_dir = DEFAULT_CACHE_DIR
+end
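These module-level accessors control where model files come from and where they land. A minimal sketch of overriding the defaults (values are illustrative):

```ruby
# Cache model files somewhere other than ~/.cache/informers
Informers.cache_dir = "/tmp/informers-cache"

# Never hit the network; same effect as setting the
# INFORMERS_OFFLINE environment variable to any non-empty value
Informers.allow_remote_models = false

# Download URLs are built from remote_host plus remote_path_template,
# with {model} and {revision} substituted in
Informers.remote_host = "https://huggingface.co/"
```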
data/lib/informers/model.rb
ADDED
@@ -0,0 +1,31 @@
+module Informers
+  class Model
+    def initialize(model_id, quantized: false)
+      @model_id = model_id
+      @model = Informers.pipeline("feature-extraction", model_id, quantized: quantized)
+
+      # TODO better pattern
+      if model_id == "sentence-transformers/all-MiniLM-L6-v2"
+        @model.instance_variable_get(:@model).instance_variable_set(:@output_names, ["sentence_embedding"])
+      end
+    end
+
+    def embed(texts)
+      is_batched = texts.is_a?(Array)
+      texts = [texts] unless is_batched
+
+      case @model_id
+      when "sentence-transformers/all-MiniLM-L6-v2"
+        output = @model.(texts)
+      when "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1", "Supabase/gte-small"
+        output = @model.(texts, pooling: "mean", normalize: true)
+      when "mixedbread-ai/mxbai-embed-large-v1"
+        output = @model.(texts, pooling: "cls")
+      else
+        raise Error, "model not supported: #{@model_id}"
+      end
+
+      is_batched ? output : output[0]
+    end
+  end
+end
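Note how `embed` handles batching: a bare string is wrapped in an array before the pipeline call and unwrapped on the way out. A usage sketch:

```ruby
model = Informers::Model.new("Supabase/gte-small")

model.embed("That is a happy person")               # single embedding (array of floats)
model.embed(["happy person", "very happy person"])  # one embedding per input string
```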
data/lib/informers/models.rb
ADDED
@@ -0,0 +1,294 @@
+module Informers
+  MODEL_TYPES = {
+    EncoderOnly: 0,
+    EncoderDecoder: 1,
+    Seq2Seq: 2,
+    Vision2Seq: 3,
+    DecoderOnly: 4,
+    MaskGeneration: 5
+  }
+
+  # NOTE: These will be populated fully later
+  MODEL_TYPE_MAPPING = {}
+  MODEL_NAME_TO_CLASS_MAPPING = {}
+  MODEL_CLASS_TO_NAME_MAPPING = {}
+
+  class PretrainedMixin
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      quantized: true,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      model_file_name: nil
+    )
+      options = {
+        quantized:,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:,
+        model_file_name:
+      }
+      config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)
+      if options[:config].nil?
+        # If no config was passed, reuse this config for future processing
+        options[:config] = config
+      end
+
+      if !const_defined?(:MODEL_CLASS_MAPPINGS)
+        raise Error, "`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}"
+      end
+
+      const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
+        model_info = model_class_mapping[config.model_type]
+        if !model_info
+          next # Item not found in this mapping
+        end
+        return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)
+      end
+
+      if const_defined?(:BASE_IF_FAIL)
+        warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+        PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
+      else
+        raise Error, "Unsupported model type: #{config.model_type}"
+      end
+    end
+  end
+
+  class PreTrainedModel
+    attr_reader :config
+
+    def initialize(config, session)
+      super()
+
+      @config = config
+      @session = session
+
+      @output_names = nil
+
+      model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+      model_type = MODEL_TYPE_MAPPING[model_name]
+
+      case model_type
+      when MODEL_TYPES[:DecoderOnly]
+        raise Todo
+      when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
+        raise Todo
+      when MODEL_TYPES[:EncoderDecoder]
+        raise Todo
+      else
+        @forward = method(:encoder_forward)
+      end
+    end
+
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      quantized: true,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      model_file_name: nil
+    )
+      options = {
+        quantized:,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:,
+        model_file_name:
+      }
+
+      model_name = MODEL_CLASS_TO_NAME_MAPPING[self]
+      model_type = MODEL_TYPE_MAPPING[model_name]
+
+      if model_type == MODEL_TYPES[:DecoderOnly]
+        raise Todo
+
+      elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
+        raise Todo
+
+      elsif model_type == MODEL_TYPES[:MaskGeneration]
+        raise Todo
+
+      elsif model_type == MODEL_TYPES[:EncoderDecoder]
+        raise Todo
+
+      else
+        if model_type != MODEL_TYPES[:EncoderOnly]
+          warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+        end
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, options[:model_file_name] || "model", **options)
+        ]
+      end
+
+      new(*info)
+    end
+
+    def self.construct_session(pretrained_model_name_or_path, file_name, **options)
+      model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+      path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
+
+      OnnxRuntime::InferenceSession.new(path)
+    end
+
+    def call(model_inputs)
+      @forward.(model_inputs)
+    end
+
+    private
+
+    def encoder_forward(model_inputs)
+      encoder_feeds = {}
+      @session.inputs.each do |input|
+        key = input[:name].to_sym
+        encoder_feeds[key] = model_inputs[key]
+      end
+      if @session.inputs.any? { |v| v[:name] == "token_type_ids" } && !encoder_feeds[:token_type_ids]
+        raise Todo
+      end
+      session_run(@session, encoder_feeds)
+    end
+
+    def session_run(session, inputs)
+      checked_inputs = validate_inputs(session, inputs)
+      begin
+        output = session.run(@output_names, checked_inputs)
+        output = replace_tensors(output)
+        output
+      rescue => e
+        raise e
+      end
+    end
+
+    # TODO
+    def replace_tensors(obj)
+      obj
+    end
+
+    # TODO
+    def validate_inputs(session, inputs)
+      inputs
+    end
+  end
+
+  class BertPreTrainedModel < PreTrainedModel
+  end
+
+  class BertModel < BertPreTrainedModel
+  end
+
+  class BertForSequenceClassification < BertPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class BertForTokenClassification < BertPreTrainedModel
+    def call(model_inputs)
+      TokenClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DistilBertPreTrainedModel < PreTrainedModel
+  end
+
+  class DistilBertModel < DistilBertPreTrainedModel
+  end
+
+  class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+    def call(model_inputs)
+      QuestionAnsweringModelOutput.new(*super(model_inputs))
+    end
+  end
+
+  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+    "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
+    "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+  }
+
+  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+    "bert" => ["BertForTokenClassification", BertForTokenClassification]
+  }
+
+  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
+  }
+
+  MODEL_CLASS_TYPE_MAPPING = [
+    [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+  ]
+
+  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
+    mappings.values.each do |name, model|
+      MODEL_TYPE_MAPPING[name] = type
+      MODEL_CLASS_TO_NAME_MAPPING[model] = name
+      MODEL_NAME_TO_CLASS_MAPPING[name] = model
+    end
+  end
+
+  class AutoModel < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }
+    BASE_IF_FAIL = true
+  end
+
+  class AutoModelForSequenceClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForTokenClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForQuestionAnswering < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
+  end
+
+  class ModelOutput
+  end
+
+  class SequenceClassifierOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
+  class TokenClassifierOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
+  class QuestionAnsweringModelOutput < ModelOutput
+    attr_reader :start_logits, :end_logits
+
+    def initialize(start_logits, end_logits)
+      super()
+      @start_logits = start_logits
+      @end_logits = end_logits
+    end
+  end
+end
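The `*_MAPPING_NAMES` hashes and the population loop above form a small registry: each concrete class is linked to its string name and model type, and the `AutoModel*` classes dispatch through it using `config.model_type`. An illustrative walk-through of the resulting lookups (expected values shown as comments):

```ruby
# "distilbert" (a config.json model_type) resolves to a concrete class
Informers::MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES["distilbert"]
# => ["DistilBertForQuestionAnswering", Informers::DistilBertForQuestionAnswering]

# The loop inverted the mapping so a class can find its own name...
Informers::MODEL_CLASS_TO_NAME_MAPPING[Informers::DistilBertForQuestionAnswering]
# => "DistilBertForQuestionAnswering"

# ...and its model type, which PreTrainedModel#initialize uses to pick
# the forward method (encoder_forward for encoder-only models)
Informers::MODEL_TYPE_MAPPING["DistilBertForQuestionAnswering"] == Informers::MODEL_TYPES[:EncoderOnly]
# => true
```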