informers 0.2.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +63 -99
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +28 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 37ea3d1f5f6e4988731e3c3dd5854ede2fb0211a5dbde18fe70d09a713b12a1c
|
4
|
+
data.tar.gz: ac7b05dc9364e1984d35ccbfc2b7604d8ec9dc76f0f8c1a33f21ba489deed8f4
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: dcd02d4ff94ed472713de26e781cfbf963136eb07da1a9a195c4482c585e1b8ab19875583118f33669b10005bf08f607c09af040b3f53bbed896fb6d19fcf9e4
|
7
|
+
data.tar.gz: 990ea77bf9fdf859354d5532d0a1acefec6576b1d322efb41b27a43aa06b1f0fa2dea81825d0ca3631969bfd0aaf1323091defddc7ed951557370095ab7d209b
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -1,15 +1,10 @@
|
|
1
1
|
# Informers
|
2
2
|
|
3
|
-
:
|
3
|
+
:fire: Fast [transformer](https://github.com/xenova/transformers.js) inference for Ruby
|
4
4
|
|
5
|
-
|
5
|
+
For non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby)
|
6
6
|
|
7
|
-
|
8
|
-
- Question answering
|
9
|
-
- Named-entity recognition
|
10
|
-
- Text generation
|
11
|
-
|
12
|
-
[](https://github.com/ankane/informers/actions)
|
7
|
+
[](https://github.com/ankane/informers/actions)
|
13
8
|
|
14
9
|
## Installation
|
15
10
|
|
@@ -21,140 +16,111 @@ gem "informers"
|
|
21
16
|
|
22
17
|
## Getting Started
|
23
18
|
|
24
|
-
- [
|
25
|
-
- [
|
26
|
-
- [Named-entity recognition](#named-entity-recognition)
|
27
|
-
- [Text generation](#text-generation)
|
28
|
-
- [Feature extraction](#feature-extraction)
|
29
|
-
- [Fill mask](#fill-mask)
|
19
|
+
- [Models](#models)
|
20
|
+
- [Pipelines](#pipelines)
|
30
21
|
|
31
|
-
|
22
|
+
## Models
|
32
23
|
|
33
|
-
|
24
|
+
### sentence-transformers/all-MiniLM-L6-v2
|
34
25
|
|
35
|
-
|
26
|
+
[Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
|
36
27
|
|
37
28
|
```ruby
|
38
|
-
|
39
|
-
model.predict("This is super cool")
|
40
|
-
```
|
29
|
+
sentences = ["This is an example sentence", "Each sentence is converted"]
|
41
30
|
|
42
|
-
|
43
|
-
|
44
|
-
```ruby
|
45
|
-
{label: "positive", score: 0.999855186578301}
|
31
|
+
model = Informers::Model.new("sentence-transformers/all-MiniLM-L6-v2")
|
32
|
+
embeddings = model.embed(sentences)
|
46
33
|
```
|
47
34
|
|
48
|
-
|
35
|
+
For a quantized version, use:
|
49
36
|
|
50
37
|
```ruby
|
51
|
-
model.
|
38
|
+
model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)
|
52
39
|
```
|
53
40
|
|
54
|
-
###
|
55
|
-
|
56
|
-
First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
|
57
|
-
|
58
|
-
Ask a question with some context
|
41
|
+
### Xenova/multi-qa-MiniLM-L6-cos-v1
|
59
42
|
|
60
|
-
|
61
|
-
model = Informers::QuestionAnswering.new("question-answering.onnx")
|
62
|
-
model.predict(
|
63
|
-
question: "Who invented Ruby?",
|
64
|
-
context: "Ruby is a programming language created by Matz"
|
65
|
-
)
|
66
|
-
```
|
67
|
-
|
68
|
-
This returns
|
43
|
+
[Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)
|
69
44
|
|
70
45
|
```ruby
|
71
|
-
|
46
|
+
query = "How many people live in London?"
|
47
|
+
docs = ["Around 9 Million people live in London", "London is known for its financial district"]
|
48
|
+
|
49
|
+
model = Informers::Model.new("Xenova/multi-qa-MiniLM-L6-cos-v1")
|
50
|
+
query_embedding = model.embed(query)
|
51
|
+
doc_embeddings = model.embed(docs)
|
52
|
+
scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
|
53
|
+
doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
|
72
54
|
```
|
73
55
|
|
74
|
-
###
|
56
|
+
### mixedbread-ai/mxbai-embed-large-v1
|
75
57
|
|
76
|
-
|
77
|
-
|
78
|
-
Get entities
|
58
|
+
[Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
|
79
59
|
|
80
60
|
```ruby
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
{text: "Nat", tag: "person", score: 0.9840519576513487, start: 0, end: 3},
|
90
|
-
{text: "GitHub", tag: "org", score: 0.9426134775785775, start: 13, end: 19},
|
91
|
-
{text: "San Francisco", tag: "location", score: 0.9952414982243061, start: 23, end: 36}
|
61
|
+
def transform_query(query)
|
62
|
+
"Represent this sentence for searching relevant passages: #{query}"
|
63
|
+
end
|
64
|
+
|
65
|
+
docs = [
|
66
|
+
transform_query("puppy"),
|
67
|
+
"The dog is barking",
|
68
|
+
"The cat is purring"
|
92
69
|
]
|
93
|
-
```
|
94
70
|
|
95
|
-
|
71
|
+
model = Informers::Model.new("mixedbread-ai/mxbai-embed-large-v1")
|
72
|
+
embeddings = model.embed(docs)
|
73
|
+
```
|
96
74
|
|
97
|
-
|
75
|
+
## Pipelines
|
98
76
|
|
99
|
-
|
77
|
+
Named-entity recognition
|
100
78
|
|
101
79
|
```ruby
|
102
|
-
|
103
|
-
|
80
|
+
ner = Informers.pipeline("ner")
|
81
|
+
ner.("Ruby is a programming language created by Matz")
|
104
82
|
```
|
105
83
|
|
106
|
-
|
84
|
+
Sentiment analysis
|
107
85
|
|
108
|
-
```
|
109
|
-
|
86
|
+
```ruby
|
87
|
+
classifier = Informers.pipeline("sentiment-analysis")
|
88
|
+
classifier.("We are very happy to show you the 🤗 Transformers library.")
|
110
89
|
```
|
111
90
|
|
112
|
-
|
113
|
-
|
114
|
-
First, export a [pretrained model](tools/export.md).
|
91
|
+
Question answering
|
115
92
|
|
116
93
|
```ruby
|
117
|
-
|
118
|
-
|
94
|
+
qa = Informers.pipeline("question-answering")
|
95
|
+
qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
|
119
96
|
```
|
120
97
|
|
121
|
-
|
122
|
-
|
123
|
-
First, export a [pretrained model](tools/export.md).
|
98
|
+
Feature extraction
|
124
99
|
|
125
100
|
```ruby
|
126
|
-
|
127
|
-
|
101
|
+
extractor = Informers.pipeline("feature-extraction")
|
102
|
+
extractor.("We are very happy to show you the 🤗 Transformers library.")
|
128
103
|
```
|
129
104
|
|
130
|
-
##
|
131
|
-
|
132
|
-
Task | Description | Contributor | License | Link
|
133
|
-
--- | --- | --- | --- | ---
|
134
|
-
Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
|
135
|
-
Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
|
136
|
-
Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
|
137
|
-
Text generation | GPT-2 | OpenAI | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
|
138
|
-
|
139
|
-
Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
|
140
|
-
|
141
|
-
## Deployment
|
105
|
+
## Credits
|
142
106
|
|
143
|
-
|
107
|
+
This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.
|
144
108
|
|
145
|
-
|
146
|
-
trove push sentiment-analysis.onnx
|
147
|
-
```
|
109
|
+
## Upgrading
|
148
110
|
|
149
|
-
|
111
|
+
### 1.0
|
150
112
|
|
151
|
-
|
113
|
+
Task classes have been replaced with the `pipeline` method.
|
152
114
|
|
153
|
-
|
154
|
-
|
155
|
-
|
115
|
+
```ruby
|
116
|
+
# before
|
117
|
+
model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
|
118
|
+
model.predict("This is super cool")
|
156
119
|
|
157
|
-
|
120
|
+
# after
|
121
|
+
model = Informers.pipeline("sentiment-analysis")
|
122
|
+
model.("This is super cool")
|
123
|
+
```
|
158
124
|
|
159
125
|
## History
|
160
126
|
|
@@ -175,7 +141,5 @@ To get started with development:
|
|
175
141
|
git clone https://github.com/ankane/informers.git
|
176
142
|
cd informers
|
177
143
|
bundle install
|
178
|
-
|
179
|
-
export MODELS_PATH=path/to/onnx/models
|
180
144
|
bundle exec rake test
|
181
145
|
```
|
@@ -0,0 +1,48 @@
|
|
1
|
+
module Informers
  # The subset of a model's config.json that this library reads.
  class PretrainedConfig
    attr_reader :model_type, :problem_type, :id2label

    # @param config_json [#[]] parsed config.json (string keys), or another
    #   object answering #[] with string keys (e.g. a PretrainedConfig)
    def initialize(config_json)
      @is_encoder_decoder = false

      @model_type = config_json["model_type"]
      @problem_type = config_json["problem_type"]
      @id2label = config_json["id2label"]
    end

    # Hash-style access to the stored attributes, e.g. config[:id2label].
    # Returns nil for attributes that were never set; note that
    # instance_variable_get raises NameError for keys that are not valid
    # instance-variable names.
    def [](key)
      instance_variable_get("@#{key}")
    end

    # Builds a config for +pretrained_model_name_or_path+.
    #
    # When +config+ is already a PretrainedConfig it is returned unchanged
    # (callers such as PretrainedMixin pass their loaded config back in, and
    # re-wrapping it would only produce a copy). A Hash +config+ is wrapped;
    # otherwise config.json is fetched via Utils::Hub.
    #
    # Extra keyword options (e.g. quantized:, model_file_name:) are accepted
    # so callers can forward their whole option hash; they are ignored here.
    def self.from_pretrained(
      pretrained_model_name_or_path,
      progress_callback: nil,
      config: nil,
      cache_dir: nil,
      local_files_only: false,
      revision: "main",
      **_kwargs
    )
      return config if config.is_a?(PretrainedConfig)

      data = config || load_config(
        pretrained_model_name_or_path,
        progress_callback:,
        config:,
        cache_dir:,
        local_files_only:,
        revision:
      )
      new(data)
    end

    # Fetches and parses config.json for the given model.
    def self.load_config(pretrained_model_name_or_path, **options)
      Utils::Hub.get_model_json(pretrained_model_name_or_path, "config.json", true, **options)
    end
  end

  # Entry point mirroring the Transformers.js API; delegates to PretrainedConfig.
  class AutoConfig
    def self.from_pretrained(...)
      PretrainedConfig.from_pretrained(...)
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Informers
  # Base cache directory. Per the XDG Base Directory spec, $XDG_CACHE_HOME is
  # used when set and non-empty; otherwise it falls back to ~/.cache. The
  # fallback is only computed when needed, so a missing $HOME no longer raises
  # at load time when $XDG_CACHE_HOME is set (ENV.fetch's default argument
  # would have been evaluated eagerly).
  CACHE_HOME =
    if ENV["XDG_CACHE_HOME"].to_s.empty?
      File.join(Dir.home, ".cache")
    else
      ENV["XDG_CACHE_HOME"]
    end
  DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, "informers"))

  class << self
    attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir
  end

  # Setting INFORMERS_OFFLINE to any non-empty value disables remote downloads.
  self.allow_remote_models = ENV["INFORMERS_OFFLINE"].to_s.empty?
  self.remote_host = "https://huggingface.co/"
  self.remote_path_template = "{model}/resolve/{revision}/"

  self.cache_dir = DEFAULT_CACHE_DIR
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
module Informers
  # Sentence-embedding wrapper around the "feature-extraction" pipeline for a
  # fixed set of supported model ids.
  class Model
    # Keyword options passed to the pipeline by #embed, per supported model id.
    EMBED_OPTIONS = {
      "sentence-transformers/all-MiniLM-L6-v2" => {}.freeze,
      "Xenova/all-MiniLM-L6-v2" => {pooling: "mean", normalize: true}.freeze,
      "Xenova/multi-qa-MiniLM-L6-cos-v1" => {pooling: "mean", normalize: true}.freeze,
      "mixedbread-ai/mxbai-embed-large-v1" => {pooling: "cls"}.freeze
    }.freeze

    # Model ids whose ONNX session should be asked for specific output names.
    OUTPUT_NAMES = {
      "sentence-transformers/all-MiniLM-L6-v2" => ["sentence_embedding"].freeze
    }.freeze

    # @param model_id [String] Hugging Face model id
    # @param quantized [Boolean] whether to load the quantized ONNX weights
    def initialize(model_id, quantized: false)
      @model_id = model_id
      @model = Informers.pipeline("feature-extraction", model_id, quantized: quantized)

      # TODO better pattern: this reaches into the pipeline's underlying model
      output_names = OUTPUT_NAMES[model_id]
      if output_names
        @model.instance_variable_get(:@model).instance_variable_set(:@output_names, output_names.dup)
      end
    end

    # Embeds one text or a batch of texts.
    #
    # @param texts [String, Array<String>]
    # @return the pipeline output for a single text, or an array of outputs
    #   when +texts+ is an Array
    # @raise [Error] if the model id is not supported (checked here, after the
    #   pipeline has already been loaded in #initialize)
    def embed(texts)
      is_batched = texts.is_a?(Array)
      texts = [texts] unless is_batched

      options = EMBED_OPTIONS.fetch(@model_id) { raise Error, "model not supported: #{@model_id}" }
      output = @model.(texts, **options)

      is_batched ? output : output[0]
    end
  end
end
|
@@ -0,0 +1,294 @@
|
|
1
|
+
module Informers
  # Architecture families a model class can belong to. Only EncoderOnly is
  # implemented below; the other families raise Todo.
  MODEL_TYPES = {
    EncoderOnly: 0,
    EncoderDecoder: 1,
    Seq2Seq: 2,
    Vision2Seq: 3,
    DecoderOnly: 4,
    MaskGeneration: 5
  }.freeze

  # NOTE: These are populated from MODEL_CLASS_TYPE_MAPPING further down, so
  # they must stay mutable.
  MODEL_TYPE_MAPPING = {}
  MODEL_NAME_TO_CLASS_MAPPING = {}
  MODEL_CLASS_TO_NAME_MAPPING = {}

  # Base for the Auto* classes. Subclasses define MODEL_CLASS_MAPPINGS (an
  # array of {model_type => [class_name, class]} hashes) and may define
  # BASE_IF_FAIL to fall back to PreTrainedModel for unknown model types.
  class PretrainedMixin
    # Loads the config for +pretrained_model_name_or_path+ and instantiates the
    # first model class registered for its model_type.
    #
    # @raise [Error] if the subclass defines no MODEL_CLASS_MAPPINGS, or if the
    #   model type is not registered and BASE_IF_FAIL is not defined
    def self.from_pretrained(
      pretrained_model_name_or_path,
      quantized: true,
      progress_callback: nil,
      config: nil,
      cache_dir: nil,
      local_files_only: false,
      revision: "main",
      model_file_name: nil
    )
      options = {
        quantized:,
        progress_callback:,
        config:,
        cache_dir:,
        local_files_only:,
        revision:,
        model_file_name:
      }
      config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)
      # If no config was passed, reuse this config for future processing
      options[:config] = config if options[:config].nil?

      unless const_defined?(:MODEL_CLASS_MAPPINGS)
        raise Error, "`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}"
      end

      # The first mapping that knows this model type wins.
      const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
        model_info = model_class_mapping[config.model_type]
        return model_info[1].from_pretrained(pretrained_model_name_or_path, **options) if model_info
      end

      raise Error, "Unsupported model type: #{config.model_type}" unless const_defined?(:BASE_IF_FAIL)

      warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
      PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
    end
  end

  # Base class pairing a model config with an ONNX Runtime inference session.
  class PreTrainedModel
    attr_reader :config

    # @param config [PretrainedConfig] model configuration
    # @param session [OnnxRuntime::InferenceSession] loaded ONNX session
    # @raise [Todo] for architecture families that are not implemented yet
    def initialize(config, session)
      super()

      @config = config
      @session = session

      # Passed to session.run; presumably nil requests every model output.
      # Informers::Model overrides it for some models.
      @output_names = nil

      model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
      model_type = MODEL_TYPE_MAPPING[model_name]

      case model_type
      when MODEL_TYPES[:DecoderOnly]
        raise Todo
      when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
        raise Todo
      when MODEL_TYPES[:EncoderDecoder]
        raise Todo
      else
        @forward = method(:encoder_forward)
      end
    end

    # Loads the config and ONNX session for +pretrained_model_name_or_path+ and
    # builds an instance of the receiving class.
    #
    # @raise [Todo] for architecture families that are not implemented yet
    def self.from_pretrained(
      pretrained_model_name_or_path,
      quantized: true,
      progress_callback: nil,
      config: nil,
      cache_dir: nil,
      local_files_only: false,
      revision: "main",
      model_file_name: nil
    )
      options = {
        quantized:,
        progress_callback:,
        config:,
        cache_dir:,
        local_files_only:,
        revision:,
        model_file_name:
      }

      model_name = MODEL_CLASS_TO_NAME_MAPPING[self]
      model_type = MODEL_TYPE_MAPPING[model_name]

      case model_type
      when MODEL_TYPES[:DecoderOnly],
        MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq],
        MODEL_TYPES[:MaskGeneration],
        MODEL_TYPES[:EncoderDecoder]
        raise Todo
      end

      if model_type != MODEL_TYPES[:EncoderOnly]
        warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
      end

      new(
        AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
        construct_session(pretrained_model_name_or_path, options[:model_file_name] || "model", **options)
      )
    end

    # Resolves onnx/<file_name>[_quantized].onnx via Utils::Hub and opens it.
    def self.construct_session(pretrained_model_name_or_path, file_name, **options)
      model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
      path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)

      OnnxRuntime::InferenceSession.new(path)
    end

    # Runs the model's forward pass on +model_inputs+ (a Hash keyed by symbol).
    def call(model_inputs)
      @forward.(model_inputs)
    end

    private

    # Feeds only the inputs the session declares, looked up by symbol name.
    def encoder_forward(model_inputs)
      input_names = @session.inputs.map { |input| input[:name] }
      encoder_feeds = input_names.to_h { |input_name| [input_name.to_sym, model_inputs[input_name.to_sym]] }
      if input_names.include?("token_type_ids") && !encoder_feeds[:token_type_ids]
        raise Todo
      end
      session_run(@session, encoder_feeds)
    end

    # Runs +session+ on +inputs+, requesting @output_names.
    def session_run(session, inputs)
      checked_inputs = validate_inputs(session, inputs)
      replace_tensors(session.run(@output_names, checked_inputs))
    end

    # TODO
    def replace_tensors(obj)
      obj
    end

    # TODO
    def validate_inputs(session, inputs)
      inputs
    end
  end

  class BertPreTrainedModel < PreTrainedModel
  end

  class BertModel < BertPreTrainedModel
  end

  class BertForSequenceClassification < BertPreTrainedModel
    def call(model_inputs)
      SequenceClassifierOutput.new(*super(model_inputs))
    end
  end

  class BertForTokenClassification < BertPreTrainedModel
    def call(model_inputs)
      TokenClassifierOutput.new(*super(model_inputs))
    end
  end

  class DistilBertPreTrainedModel < PreTrainedModel
  end

  class DistilBertModel < DistilBertPreTrainedModel
  end

  class DistilBertForSequenceClassification < DistilBertPreTrainedModel
    def call(model_inputs)
      SequenceClassifierOutput.new(*super(model_inputs))
    end
  end

  class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
    def call(model_inputs)
      QuestionAnsweringModelOutput.new(*super(model_inputs))
    end
  end

  # config.json model_type => [class name, class]
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
    "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
    "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
  }

  MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
    "bert" => ["BertForTokenClassification", BertForTokenClassification]
  }

  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
    "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
  }

  MODEL_CLASS_TYPE_MAPPING = [
    [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
    [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
  ]

  # Populate the lookup tables declared at the top of the module.
  MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
    mappings.each_value do |name, model|
      MODEL_TYPE_MAPPING[name] = type
      MODEL_CLASS_TO_NAME_MAPPING[model] = name
      MODEL_NAME_TO_CLASS_MAPPING[name] = model
    end
  end

  class AutoModel < PretrainedMixin
    MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }
    BASE_IF_FAIL = true
  end

  class AutoModelForSequenceClassification < PretrainedMixin
    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]
  end

  class AutoModelForTokenClassification < PretrainedMixin
    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
  end

  class AutoModelForQuestionAnswering < PretrainedMixin
    MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
  end

  class ModelOutput
  end

  class SequenceClassifierOutput < ModelOutput
    attr_reader :logits

    def initialize(logits)
      super()
      @logits = logits
    end
  end

  class TokenClassifierOutput < ModelOutput
    attr_reader :logits

    def initialize(logits)
      super()
      @logits = logits
    end
  end

  class QuestionAnsweringModelOutput < ModelOutput
    attr_reader :start_logits, :end_logits

    def initialize(start_logits, end_logits)
      super()
      @start_logits = start_logits
      @end_logits = end_logits
    end
  end
end
|