informers 0.2.0 → 1.0.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +63 -99
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +28 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 37ea3d1f5f6e4988731e3c3dd5854ede2fb0211a5dbde18fe70d09a713b12a1c
+  data.tar.gz: ac7b05dc9364e1984d35ccbfc2b7604d8ec9dc76f0f8c1a33f21ba489deed8f4
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: dcd02d4ff94ed472713de26e781cfbf963136eb07da1a9a195c4482c585e1b8ab19875583118f33669b10005bf08f607c09af040b3f53bbed896fb6d19fcf9e4
+  data.tar.gz: 990ea77bf9fdf859354d5532d0a1acefec6576b1d322efb41b27a43aa06b1f0fa2dea81825d0ca3631969bfd0aaf1323091defddc7ed951557370095ab7d209b
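The digests above cover the two archives packed inside the published `.gem` file, which is itself a plain tar archive containing `metadata.gz` and `data.tar.gz`. A rough verification sketch (the local filename is an assumption; adjust to wherever the gem was downloaded):

```ruby
require "digest"
require "rubygems/package"

# Hash the two tar members that checksums.yaml covers and compare
# the output with the published SHA256 values above.
File.open("informers-1.0.0.gem", "rb") do |io|
  Gem::Package::TarReader.new(io) do |tar|
    tar.each do |entry|
      next unless ["metadata.gz", "data.tar.gz"].include?(entry.full_name)
      puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
    end
  end
end
```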
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -1,15 +1,10 @@
 # Informers

-:
+:fire: Fast [transformer](https://github.com/xenova/transformers.js) inference for Ruby

-
+For non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby)

-- Sentiment analysis
-- Question answering
-- Named-entity recognition
-- Text generation
-
-[![Build Status](https://github.com/ankane/informers/workflows/build/badge.svg?branch=master)](https://github.com/ankane/informers/actions)
+[![Build Status](https://github.com/ankane/informers/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/informers/actions)

 ## Installation

@@ -21,140 +16,111 @@ gem "informers"

 ## Getting Started

-- [Sentiment analysis](#sentiment-analysis)
-- [Question answering](#question-answering)
-- [Named-entity recognition](#named-entity-recognition)
-- [Text generation](#text-generation)
-- [Feature extraction](#feature-extraction)
-- [Fill mask](#fill-mask)
+- [Models](#models)
+- [Pipelines](#pipelines)

-
+## Models

-
+### sentence-transformers/all-MiniLM-L6-v2

-
+[Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

 ```ruby
-
-model.predict("This is super cool")
-```
+sentences = ["This is an example sentence", "Each sentence is converted"]

-
-
-```ruby
-{label: "positive", score: 0.999855186578301}
+model = Informers::Model.new("sentence-transformers/all-MiniLM-L6-v2")
+embeddings = model.embed(sentences)
 ```

-
+For a quantized version, use:

 ```ruby
-model.
+model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)
 ```

-### Question answering
-
-First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
-
-Ask a question with some context
+### Xenova/multi-qa-MiniLM-L6-cos-v1

-
-model = Informers::QuestionAnswering.new("question-answering.onnx")
-model.predict(
-  question: "Who invented Ruby?",
-  context: "Ruby is a programming language created by Matz"
-)
-```
-
-This returns
+[Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)

 ```ruby
-
+query = "How many people live in London?"
+docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+model = Informers::Model.new("Xenova/multi-qa-MiniLM-L6-cos-v1")
+query_embedding = model.embed(query)
+doc_embeddings = model.embed(docs)
+scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
+doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
 ```

-### Named-entity recognition
+### mixedbread-ai/mxbai-embed-large-v1

-
-
-Get entities
+[Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

 ```ruby
-
-
-
-
-
-
-
-
-  {text: "Nat", tag: "person", score: 0.9840519576513487, start: 0, end: 3},
-  {text: "GitHub", tag: "org", score: 0.9426134775785775, start: 13, end: 19},
-  {text: "San Francisco", tag: "location", score: 0.9952414982243061, start: 23, end: 36}
+def transform_query(query)
+  "Represent this sentence for searching relevant passages: #{query}"
+end
+
+docs = [
+  transform_query("puppy"),
+  "The dog is barking",
+  "The cat is purring"
 ]
-```

-
+model = Informers::Model.new("mixedbread-ai/mxbai-embed-large-v1")
+embeddings = model.embed(docs)
+```

-
+## Pipelines

-
+Named-entity recognition

 ```ruby
-
-
+ner = Informers.pipeline("ner")
+ner.("Ruby is a programming language created by Matz")
 ```

-
+Sentiment analysis

-```
-
+```ruby
+classifier = Informers.pipeline("sentiment-analysis")
+classifier.("We are very happy to show you the 🤗 Transformers library.")
 ```

-
-
-First, export a [pretrained model](tools/export.md).
+Question answering

 ```ruby
-
-
+qa = Informers.pipeline("question-answering")
+qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
 ```

-
-
-First, export a [pretrained model](tools/export.md).
+Feature extraction

 ```ruby
-
-
+extractor = Informers.pipeline("feature-extraction")
+extractor.("We are very happy to show you the 🤗 Transformers library.")
 ```

-##
-
-Task | Description | Contributor | License | Link
---- | --- | --- | --- | ---
-Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
-Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
-Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
-Text generation | GPT-2 | OpenAI | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
-
-Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
-
-## Deployment
+## Credits

-
+This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.

-
-trove push sentiment-analysis.onnx
-```
+## Upgrading

-
+### 1.0

-
+Task classes have been replaced with the `pipeline` method.

-
-
-
+```ruby
+# before
+model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
+model.predict("This is super cool")

-
+# after
+model = Informers.pipeline("sentiment-analysis")
+model.("This is super cool")
+```

 ## History

@@ -175,7 +141,5 @@ To get started with development:
 git clone https://github.com/ankane/informers.git
 cd informers
 bundle install
-
-export MODELS_PATH=path/to/onnx/models
 bundle exec rake test
 ```
data/lib/informers/configs.rb
ADDED
@@ -0,0 +1,48 @@
+module Informers
+  class PretrainedConfig
+    attr_reader :model_type, :problem_type, :id2label
+
+    def initialize(config_json)
+      @is_encoder_decoder = false
+
+      @model_type = config_json["model_type"]
+      @problem_type = config_json["problem_type"]
+      @id2label = config_json["id2label"]
+    end
+
+    def [](key)
+      instance_variable_get("@#{key}")
+    end
+
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      **kwargs
+    )
+      data = config || load_config(
+        pretrained_model_name_or_path,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:
+      )
+      new(data)
+    end
+
+    def self.load_config(pretrained_model_name_or_path, **options)
+      info = Utils::Hub.get_model_json(pretrained_model_name_or_path, "config.json", true, **options)
+      info
+    end
+  end
+
+  class AutoConfig
+    def self.from_pretrained(...)
+      PretrainedConfig.from_pretrained(...)
+    end
+  end
+end
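`PretrainedConfig` is a thin wrapper over the model's `config.json`, with `[]` exposing any captured attribute via `instance_variable_get`. A usage sketch (the model id is only an example, and the call downloads from the Hugging Face Hub on first use):

```ruby
config = Informers::AutoConfig.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
config.model_type  # => "distilbert"
config[:id2label]  # hash-style access to any attribute captured from config.json
```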
data/lib/informers/env.rb
ADDED
@@ -0,0 +1,14 @@
+module Informers
+  CACHE_HOME = ENV.fetch("XDG_CACHE_HOME", File.join(ENV.fetch("HOME"), ".cache"))
+  DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, "informers"))
+
+  class << self
+    attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir
+  end
+
+  self.allow_remote_models = ENV["INFORMERS_OFFLINE"].to_s.empty?
+  self.remote_host = "https://huggingface.co/"
+  self.remote_path_template = "{model}/resolve/{revision}/"
+
+  self.cache_dir = DEFAULT_CACHE_DIR
+end
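These module-level settings can be overridden before any model is loaded; a minimal sketch:

```ruby
# Store downloads somewhere other than the default ~/.cache/informers
Informers.cache_dir = "/tmp/informers"

# Rely only on previously cached files; setting the INFORMERS_OFFLINE
# environment variable before load has the same effect
Informers.allow_remote_models = false
```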
data/lib/informers/model.rb
ADDED
@@ -0,0 +1,31 @@
+module Informers
+  class Model
+    def initialize(model_id, quantized: false)
+      @model_id = model_id
+      @model = Informers.pipeline("feature-extraction", model_id, quantized: quantized)
+
+      # TODO better pattern
+      if model_id == "sentence-transformers/all-MiniLM-L6-v2"
+        @model.instance_variable_get(:@model).instance_variable_set(:@output_names, ["sentence_embedding"])
+      end
+    end
+
+    def embed(texts)
+      is_batched = texts.is_a?(Array)
+      texts = [texts] unless is_batched
+
+      case @model_id
+      when "sentence-transformers/all-MiniLM-L6-v2"
+        output = @model.(texts)
+      when "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1"
+        output = @model.(texts, pooling: "mean", normalize: true)
+      when "mixedbread-ai/mxbai-embed-large-v1"
+        output = @model.(texts, pooling: "cls")
+      else
+        raise Error, "model not supported: #{@model_id}"
+      end
+
+      is_batched ? output : output[0]
+    end
+  end
+end
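`Model#embed` mirrors the shape of its input: a single string yields one vector, an array yields an array of vectors. Since the `Xenova/all-MiniLM-L6-v2` branch mean-pools and normalizes, the dot product of two embeddings equals their cosine similarity. A sketch (the model is fetched on first use):

```ruby
model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)

embedding = model.embed("informers runs transformers in Ruby") # one vector
batch = model.embed(["first sentence", "second sentence"])     # array of vectors

# Normalized embeddings, so dot product == cosine similarity
similarity = batch[0].zip(batch[1]).sum { |a, b| a * b }
```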
data/lib/informers/models.rb
ADDED
@@ -0,0 +1,294 @@
+module Informers
+  MODEL_TYPES = {
+    EncoderOnly: 0,
+    EncoderDecoder: 1,
+    Seq2Seq: 2,
+    Vision2Seq: 3,
+    DecoderOnly: 4,
+    MaskGeneration: 5
+  }
+
+  # NOTE: These will be populated fully later
+  MODEL_TYPE_MAPPING = {}
+  MODEL_NAME_TO_CLASS_MAPPING = {}
+  MODEL_CLASS_TO_NAME_MAPPING = {}
+
+  class PretrainedMixin
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      quantized: true,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      model_file_name: nil
+    )
+      options = {
+        quantized:,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:,
+        model_file_name:
+      }
+      config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)
+      if options[:config].nil?
+        # If no config was passed, reuse this config for future processing
+        options[:config] = config
+      end
+
+      if !const_defined?(:MODEL_CLASS_MAPPINGS)
+        raise Error, "`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}"
+      end
+
+      const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
+        model_info = model_class_mapping[config.model_type]
+        if !model_info
+          next # Item not found in this mapping
+        end
+        return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)
+      end
+
+      if const_defined?(:BASE_IF_FAIL)
+        warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+        PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
+      else
+        raise Error, "Unsupported model type: #{config.model_type}"
+      end
+    end
+  end
+
+  class PreTrainedModel
+    attr_reader :config
+
+    def initialize(config, session)
+      super()
+
+      @config = config
+      @session = session
+
+      @output_names = nil
+
+      model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+      model_type = MODEL_TYPE_MAPPING[model_name]
+
+      case model_type
+      when MODEL_TYPES[:DecoderOnly]
+        raise Todo
+      when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
+        raise Todo
+      when MODEL_TYPES[:EncoderDecoder]
+        raise Todo
+      else
+        @forward = method(:encoder_forward)
+      end
+    end
+
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      quantized: true,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      model_file_name: nil
+    )
+      options = {
+        quantized:,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:,
+        model_file_name:
+      }
+
+      model_name = MODEL_CLASS_TO_NAME_MAPPING[self]
+      model_type = MODEL_TYPE_MAPPING[model_name]
+
+      if model_type == MODEL_TYPES[:DecoderOnly]
+        raise Todo
+
+      elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
+        raise Todo
+
+      elsif model_type == MODEL_TYPES[:MaskGeneration]
+        raise Todo
+
+      elsif model_type == MODEL_TYPES[:EncoderDecoder]
+        raise Todo
+
+      else
+        if model_type != MODEL_TYPES[:EncoderOnly]
+          warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+        end
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, options[:model_file_name] || "model", **options)
+        ]
+      end
+
+      new(*info)
+    end
+
+    def self.construct_session(pretrained_model_name_or_path, file_name, **options)
+      model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+      path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
+
+      OnnxRuntime::InferenceSession.new(path)
+    end
+
+    def call(model_inputs)
+      @forward.(model_inputs)
+    end
+
+    private
+
+    def encoder_forward(model_inputs)
+      encoder_feeds = {}
+      @session.inputs.each do |input|
+        key = input[:name].to_sym
+        encoder_feeds[key] = model_inputs[key]
+      end
+      if @session.inputs.any? { |v| v[:name] == "token_type_ids" } && !encoder_feeds[:token_type_ids]
+        raise Todo
+      end
+      session_run(@session, encoder_feeds)
+    end
+
+    def session_run(session, inputs)
+      checked_inputs = validate_inputs(session, inputs)
+      begin
+        output = session.run(@output_names, checked_inputs)
+        output = replace_tensors(output)
+        output
+      rescue => e
+        raise e
+      end
+    end
+
+    # TODO
+    def replace_tensors(obj)
+      obj
+    end
+
+    # TODO
+    def validate_inputs(session, inputs)
+      inputs
+    end
+  end
+
+  class BertPreTrainedModel < PreTrainedModel
+  end
+
+  class BertModel < BertPreTrainedModel
+  end
+
+  class BertForSequenceClassification < BertPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class BertForTokenClassification < BertPreTrainedModel
+    def call(model_inputs)
+      TokenClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DistilBertPreTrainedModel < PreTrainedModel
+  end
+
+  class DistilBertModel < DistilBertPreTrainedModel
+  end
+
+  class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+    def call(model_inputs)
+      QuestionAnsweringModelOutput.new(*super(model_inputs))
+    end
+  end
+
+  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+    "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
+    "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+  }
+
+  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+    "bert" => ["BertForTokenClassification", BertForTokenClassification]
+  }
+
+  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
+  }
+
+  MODEL_CLASS_TYPE_MAPPING = [
+    [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+  ]
+
+  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
+    mappings.values.each do |name, model|
+      MODEL_TYPE_MAPPING[name] = type
+      MODEL_CLASS_TO_NAME_MAPPING[model] = name
+      MODEL_NAME_TO_CLASS_MAPPING[name] = model
+    end
+  end
+
+  class AutoModel < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }
+    BASE_IF_FAIL = true
+  end
+
+  class AutoModelForSequenceClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForTokenClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForQuestionAnswering < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
+  end
+
+  class ModelOutput
+  end
+
+  class SequenceClassifierOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
+  class TokenClassifierOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
+  class QuestionAnsweringModelOutput < ModelOutput
+    attr_reader :start_logits, :end_logits
+
+    def initialize(start_logits, end_logits)
+      super()
+      @start_logits = start_logits
+      @end_logits = end_logits
+    end
+  end
+end
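The three registry hashes populated from `MODEL_CLASS_TYPE_MAPPING` are what let `PretrainedMixin.from_pretrained` dispatch on `config.model_type`. An illustrative lookup (no network required):

```ruby
# Each mapping stores [class name, class] pairs keyed by model_type,
# so resolving "distilbert" for sequence classification is a hash lookup.
name, klass = Informers::MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES["distilbert"]
name  # => "DistilBertForSequenceClassification"
klass # => Informers::DistilBertForQuestionAnswering's sibling class, DistilBertForSequenceClassification

# The reverse mappings are derived once at load time:
Informers::MODEL_CLASS_TO_NAME_MAPPING[klass] # => "DistilBertForSequenceClassification"
Informers::MODEL_TYPE_MAPPING[name]           # => 0, i.e. MODEL_TYPES[:EncoderOnly]
```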