informers 0.2.0 → 1.0.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 22f7bcebf0670078b65fdf9cba4d2b937c853a3b10cf36e47f50781e2663225c
-   data.tar.gz: 940c96ec6b749b7e0b0c283456e40bfe9e6cbb3a58e8fa11f6367e87b05d8694
+   metadata.gz: 37ea3d1f5f6e4988731e3c3dd5854ede2fb0211a5dbde18fe70d09a713b12a1c
+   data.tar.gz: ac7b05dc9364e1984d35ccbfc2b7604d8ec9dc76f0f8c1a33f21ba489deed8f4
  SHA512:
-   metadata.gz: 4cd8b58aae6e885409e297bc1ba09aedd029bb3dc26a193251f33c2bf6c9f6a8da69cb3727f799296a8c6644b014afc715e783a1e19a1074982af531e40db57b
-   data.tar.gz: 6f63489d0b303e9a7de13df11d5074bd4cb2dfa44febee4061262d5c188eeb62a7c975e89567048f801fa183c8d56925275768fccc9a4b5a48255abeeb379345
+   metadata.gz: dcd02d4ff94ed472713de26e781cfbf963136eb07da1a9a195c4482c585e1b8ab19875583118f33669b10005bf08f607c09af040b3f53bbed896fb6d19fcf9e4
+   data.tar.gz: 990ea77bf9fdf859354d5532d0a1acefec6576b1d322efb41b27a43aa06b1f0fa2dea81825d0ca3631969bfd0aaf1323091defddc7ed951557370095ab7d209b
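A `.gem` file is a tar archive whose entries include `metadata.gz`, `data.tar.gz` and `checksums.yaml.gz`; the hashes above are the ones recorded for the first two. The Ruby sketch below shows one way to check extracted entries against such a document. It assumes the entries have already been read into memory, the helper name `checksum_mismatches` is hypothetical, and the sample bytes are made up rather than taken from the informers gem.

```ruby
require "digest"
require "yaml"

# Hypothetical helper: returns the names of entries whose SHA256 or SHA512
# digest differs from the values recorded in a checksums.yaml document.
def checksum_mismatches(checksums_yaml, entries)
  sums = YAML.safe_load(checksums_yaml)
  entries.each_with_object([]) do |(name, bytes), bad|
    ok = sums.dig("SHA256", name) == Digest::SHA256.hexdigest(bytes) &&
         sums.dig("SHA512", name) == Digest::SHA512.hexdigest(bytes)
    bad << name unless ok
  end
end

# Made-up sample: a checksums document generated from known bytes
data = "example bytes"
yaml = {
  "SHA256" => {"data.tar.gz" => Digest::SHA256.hexdigest(data)},
  "SHA512" => {"data.tar.gz" => Digest::SHA512.hexdigest(data)}
}.to_yaml

p checksum_mismatches(yaml, {"data.tar.gz" => data})       # []
p checksum_mismatches(yaml, {"data.tar.gz" => "tampered"}) # ["data.tar.gz"]
```

Checking both digest families mirrors the file's layout; either one alone would already catch accidental corruption.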
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## 1.0.0 (2024-08-26)
+
+ - Replaced task classes with `pipeline` method
+ - Added `Model` class
+ - Dropped support for Ruby < 3.1
+
  ## 0.2.0 (2022-09-06)
 
  - Added support for `optimum` and `transformers.onnx` models
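The changelog does not give a reason for dropping Ruby < 3.1, but the new code in this diff relies on at least one 3.1 feature: value omission in hash literals and keyword arguments, as in `options = { quantized:, ... }` and `load_config(..., revision:)`, where `key:` with no value means `key: key`. A small illustration; the method names are hypothetical.

```ruby
# Requires Ruby 3.1+: `name:` with no value is shorthand for `name: name`
def build_options(quantized: true, revision: "main")
  {quantized:, revision:} # same as {quantized: quantized, revision: revision}
end

# Hypothetical method taking a keyword argument
def describe(model_id, revision:)
  "#{model_id}@#{revision}"
end

revision = "main"
options = build_options(quantized: false)
label = describe("Xenova/all-MiniLM-L6-v2", revision:) # shorthand in a call
puts options[:quantized] # false
puts label               # Xenova/all-MiniLM-L6-v2@main
```

On Ruby 3.0 and earlier, both shorthand forms are syntax errors, so a file using them cannot even be loaded.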
data/README.md CHANGED
@@ -1,15 +1,10 @@
  # Informers
 
- :slightly_smiling_face: State-of-the-art natural language processing for Ruby
+ :fire: Fast [transformer](https://github.com/xenova/transformers.js) inference for Ruby
 
- Supports:
+ For non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby)
 
- - Sentiment analysis
- - Question answering
- - Named-entity recognition
- - Text generation
-
- [![Build Status](https://github.com/ankane/informers/workflows/build/badge.svg?branch=master)](https://github.com/ankane/informers/actions)
+ [![Build Status](https://github.com/ankane/informers/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/informers/actions)
 
  ## Installation
 
@@ -21,140 +16,111 @@ gem "informers"
 
  ## Getting Started
 
- - [Sentiment analysis](#sentiment-analysis)
- - [Question answering](#question-answering)
- - [Named-entity recognition](#named-entity-recognition)
- - [Text generation](#text-generation)
- - [Feature extraction](#feature-extraction)
- - [Fill mask](#fill-mask)
+ - [Models](#models)
+ - [Pipelines](#pipelines)
 
- ### Sentiment Analysis
+ ## Models
 
- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/sentiment-analysis.onnx).
+ ### sentence-transformers/all-MiniLM-L6-v2
 
- Predict sentiment
+ [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
 
  ```ruby
- model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
- model.predict("This is super cool")
- ```
+ sentences = ["This is an example sentence", "Each sentence is converted"]
 
- This returns
-
- ```ruby
- {label: "positive", score: 0.999855186578301}
+ model = Informers::Model.new("sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = model.embed(sentences)
  ```
 
- Predict multiple at once
+ For a quantized version, use:
 
  ```ruby
- model.predict(["This is super cool", "I didn't like it"])
+ model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)
  ```
 
- ### Question Answering
-
- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
-
- Ask a question with some context
+ ### Xenova/multi-qa-MiniLM-L6-cos-v1
 
- ```ruby
- model = Informers::QuestionAnswering.new("question-answering.onnx")
- model.predict(
-   question: "Who invented Ruby?",
-   context: "Ruby is a programming language created by Matz"
- )
- ```
-
- This returns
+ [Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)
 
  ```ruby
- {answer: "Matz", score: 0.9980658360049758, start: 42, end: 46}
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Informers::Model.new("Xenova/multi-qa-MiniLM-L6-cos-v1")
+ query_embedding = model.embed(query)
+ doc_embeddings = model.embed(docs)
+ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
+ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```
 
- ### Named-Entity Recognition
+ ### mixedbread-ai/mxbai-embed-large-v1
 
- First, export the [pretrained model](tools/export.md).
-
- Get entities
+ [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
 
  ```ruby
- model = Informers::NER.new("ner.onnx")
- model.predict("Nat works at GitHub in San Francisco")
- ```
-
- This returns
-
- ```ruby
- [
-   {text: "Nat", tag: "person", score: 0.9840519576513487, start: 0, end: 3},
-   {text: "GitHub", tag: "org", score: 0.9426134775785775, start: 13, end: 19},
-   {text: "San Francisco", tag: "location", score: 0.9952414982243061, start: 23, end: 36}
+ def transform_query(query)
+   "Represent this sentence for searching relevant passages: #{query}"
+ end
+
+ docs = [
+   transform_query("puppy"),
+   "The dog is barking",
+   "The cat is purring"
  ]
- ```
 
- ### Text Generation
+ model = Informers::Model.new("mixedbread-ai/mxbai-embed-large-v1")
+ embeddings = model.embed(docs)
+ ```
 
- First, export the [pretrained model](tools/export.md).
+ ## Pipelines
 
- Pass a prompt
+ Named-entity recognition
 
  ```ruby
- model = Informers::TextGeneration.new("text-generation.onnx")
- model.predict("As far as I am concerned, I will", max_length: 50)
+ ner = Informers.pipeline("ner")
+ ner.("Ruby is a programming language created by Matz")
  ```
 
- This returns
+ Sentiment analysis
 
- ```text
- As far as I am concerned, I will be the first to admit that I am not a fan of the idea of a "free market." I think that the idea of a free market is a bit of a stretch. I think that the idea
+ ```ruby
+ classifier = Informers.pipeline("sentiment-analysis")
+ classifier.("We are very happy to show you the 🤗 Transformers library.")
  ```
 
- ### Feature Extraction
-
- First, export a [pretrained model](tools/export.md).
+ Question answering
 
  ```ruby
- model = Informers::FeatureExtraction.new("feature-extraction.onnx")
- model.predict("This is super cool")
+ qa = Informers.pipeline("question-answering")
+ qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
  ```
 
- ### Fill Mask
-
- First, export a [pretrained model](tools/export.md).
+ Feature extraction
 
  ```ruby
- model = Informers::FillMask.new("fill-mask.onnx")
- model.predict("This is a great <mask>")
+ extractor = Informers.pipeline("feature-extraction")
+ extractor.("We are very happy to show you the 🤗 Transformers library.")
  ```
 
- ## Models
-
- Task | Description | Contributor | License | Link
- --- | --- | --- | --- | ---
- Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
- Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
- Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
- Text generation | GPT-2 | OpenAI | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
-
- Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
-
- ## Deployment
+ ## Credits
 
- Check out [Trove](https://github.com/ankane/trove) for deploying models.
+ This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.
 
- ```sh
- trove push sentiment-analysis.onnx
- ```
+ ## Upgrading
 
- ## Credits
+ ### 1.0
 
- This project uses many state-of-the-art technologies:
+ Task classes have been replaced with the `pipeline` method.
 
- - [Transformers](https://github.com/huggingface/transformers) for transformer models
- - [Bling Fire](https://github.com/microsoft/BlingFire) and [BERT](https://github.com/google-research/bert) for high-performance text tokenization
- - [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high-performance inference
+ ```ruby
+ # before
+ model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
+ model.predict("This is super cool")
 
- Some code was ported from Transformers and is available under the same license.
+ # after
+ model = Informers.pipeline("sentiment-analysis")
+ model.("This is super cool")
+ ```
 
  ## History
 
@@ -175,7 +141,5 @@ To get started with development:
  git clone https://github.com/ankane/informers.git
  cd informers
  bundle install
-
- export MODELS_PATH=path/to/onnx/models
  bundle exec rake test
  ```
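In the `Xenova/multi-qa-MiniLM-L6-cos-v1` README example, each score is a plain dot product. That equals cosine similarity only because `Model#embed` passes `normalize: true` for that model, so every embedding has unit length. The sketch below ranks documents with an explicit cosine similarity instead, which also works for unnormalized vectors. The 3-dimensional vectors are invented stand-ins for real embeddings, and the helper names are hypothetical.

```ruby
# Cosine-similarity ranking; assumes non-zero vectors of equal length.

def dot(a, b)
  a.zip(b).sum { |x, y| x * y }
end

def cosine(a, b)
  dot(a, b) / (Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b)))
end

# Returns [doc, score] pairs with the most similar document first
def rank(query_embedding, docs, doc_embeddings)
  scores = doc_embeddings.map { |e| cosine(e, query_embedding) }
  docs.zip(scores).sort_by { |_, s| -s }
end

query_embedding = [1.0, 0.0, 1.0]
docs = ["doc A", "doc B"]
doc_embeddings = [[0.0, 1.0, 0.0], [2.0, 0.0, 2.0]]
ranked = rank(query_embedding, docs, doc_embeddings)
ranked.each { |doc, score| puts "#{doc}: #{score.round(3)}" }
# doc B: 1.0
# doc A: 0.0
```

`doc B` scores 1.0 even though its vector is twice as long as the query's, which is exactly the case where an unnormalized dot product would overstate similarity.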
@@ -0,0 +1,48 @@
+ module Informers
+   class PretrainedConfig
+     attr_reader :model_type, :problem_type, :id2label
+
+     def initialize(config_json)
+       @is_encoder_decoder = false
+
+       @model_type = config_json["model_type"]
+       @problem_type = config_json["problem_type"]
+       @id2label = config_json["id2label"]
+     end
+
+     def [](key)
+       instance_variable_get("@#{key}")
+     end
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       **kwargs
+     )
+       data = config || load_config(
+         pretrained_model_name_or_path,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:
+       )
+       new(data)
+     end
+
+     def self.load_config(pretrained_model_name_or_path, **options)
+       info = Utils::Hub.get_model_json(pretrained_model_name_or_path, "config.json", true, **options)
+       info
+     end
+   end
+
+   class AutoConfig
+     def self.from_pretrained(...)
+       PretrainedConfig.from_pretrained(...)
+     end
+   end
+ end
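`PretrainedConfig#[]` maps a key to the instance variable of the same name. That is what lets `PretrainedMixin` (in the models file later in this diff) pass an already-built config back in through `config:`: `from_pretrained` then calls `new(config)`, and `initialize` reads `config_json["model_type"]` from the config object exactly as it would from the parsed JSON Hash. The stand-in class below, `MiniConfig` (hypothetical), reproduces only that behavior; it is not the gem's class.

```ruby
# Hypothetical stand-in reproducing PretrainedConfig's constructor and [] only
class MiniConfig
  attr_reader :model_type, :problem_type, :id2label

  def initialize(config_json)
    # config_json may be a parsed Hash or anything else that responds to []
    @model_type = config_json["model_type"]
    @problem_type = config_json["problem_type"]
    @id2label = config_json["id2label"]
  end

  # Hash-style lookup backed by the instance variable of the same name
  def [](key)
    instance_variable_get("@#{key}")
  end
end

json = {"model_type" => "distilbert", "id2label" => {"0" => "NEGATIVE", "1" => "POSITIVE"}}
first = MiniConfig.new(json)
second = MiniConfig.new(first) # re-wrapping works because MiniConfig answers [] like the Hash

puts second.model_type        # distilbert
puts second["id2label"]["1"]  # POSITIVE
p second["problem_type"]      # nil (absent from the JSON)
```

One side effect of this design: a key that is not a valid instance-variable name, such as `"model-type"`, raises `NameError` instead of returning `nil` as a Hash would.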
@@ -0,0 +1,14 @@
+ module Informers
+   CACHE_HOME = ENV.fetch("XDG_CACHE_HOME", File.join(ENV.fetch("HOME"), ".cache"))
+   DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, "informers"))
+
+   class << self
+     attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir
+   end
+
+   self.allow_remote_models = ENV["INFORMERS_OFFLINE"].to_s.empty?
+   self.remote_host = "https://huggingface.co/"
+   self.remote_path_template = "{model}/resolve/{revision}/"
+
+   self.cache_dir = DEFAULT_CACHE_DIR
+ end
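These settings are read once at load time, and two details follow directly from the code. First, `ENV.fetch("HOME")` is evaluated as the default argument even when `XDG_CACHE_HOME` is set, so a missing `HOME` raises `KeyError` either way. Second, any non-empty `INFORMERS_OFFLINE` value disables remote models. The sketch below reproduces these rules on a plain Hash so they can be exercised in isolation. It deliberately uses the block form of `fetch`, so `HOME` is consulted only when needed (a change from the code above, not the gem's behavior). It also assumes remote URLs are built by substituting `{model}` and `{revision}` into `remote_path_template` and appending the file name; the `Utils::Hub` code that does this is not part of this diff. All helper names are hypothetical.

```ruby
# Hypothetical helpers mirroring the settings above, applied to a Hash instead of ENV

def cache_dir_for(env)
  # Block form: HOME is read only when XDG_CACHE_HOME is absent
  cache_home = env.fetch("XDG_CACHE_HOME") { File.join(env.fetch("HOME"), ".cache") }
  File.expand_path(File.join(cache_home, "informers"))
end

def remote_models_allowed?(env)
  # nil.to_s and "" are both empty, so only a non-empty value turns remote models off
  env["INFORMERS_OFFLINE"].to_s.empty?
end

# Assumed URL scheme: host + filled-in template + file name
def remote_url(model, revision, file_name,
               host: "https://huggingface.co/",
               template: "{model}/resolve/{revision}/")
  path = template.gsub(/\{(?:model|revision)\}/, "{model}" => model, "{revision}" => revision)
  host + path + file_name
end

puts cache_dir_for({"HOME" => "/home/me"})
puts remote_url("Xenova/all-MiniLM-L6-v2", "main", "onnx/model_quantized.onnx")
```

The file name `onnx/model_quantized.onnx` matches what `construct_session` in the models file builds for `file_name = "model"` with `quantized: true`.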
@@ -0,0 +1,31 @@
+ module Informers
+   class Model
+     def initialize(model_id, quantized: false)
+       @model_id = model_id
+       @model = Informers.pipeline("feature-extraction", model_id, quantized: quantized)
+
+       # TODO better pattern
+       if model_id == "sentence-transformers/all-MiniLM-L6-v2"
+         @model.instance_variable_get(:@model).instance_variable_set(:@output_names, ["sentence_embedding"])
+       end
+     end
+
+     def embed(texts)
+       is_batched = texts.is_a?(Array)
+       texts = [texts] unless is_batched
+
+       case @model_id
+       when "sentence-transformers/all-MiniLM-L6-v2"
+         output = @model.(texts)
+       when "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1"
+         output = @model.(texts, pooling: "mean", normalize: true)
+       when "mixedbread-ai/mxbai-embed-large-v1"
+         output = @model.(texts, pooling: "cls")
+       else
+         raise Error, "model not supported: #{@model_id}"
+       end
+
+       is_batched ? output : output[0]
+     end
+   end
+ end
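`Model#embed` picks a pooling strategy per model ID: `pooling: "mean", normalize: true` for the two Xenova MiniLM models and `pooling: "cls"` for mxbai, while `sentence-transformers/all-MiniLM-L6-v2` reads the ONNX graph's `sentence_embedding` output directly. Assuming these options follow the Transformers.js conventions this port is based on, mean pooling averages token vectors where the attention mask is 1, CLS pooling keeps the first token's vector, and normalization rescales each vector to unit L2 length. The pure-Ruby sketch below illustrates those three operations on nested arrays; it is not the gem's tensor code, and the function names are hypothetical.

```ruby
# Shapes assumed:
#   token_embeddings: [batch][tokens][hidden]
#   attention_mask:   [batch][tokens], 1 for real tokens and 0 for padding

def mean_pool(token_embeddings, attention_mask)
  token_embeddings.zip(attention_mask).map do |tokens, mask|
    kept = tokens.zip(mask).select { |_, m| m == 1 }.map(&:first)
    sums = Array.new(tokens.first.size, 0.0)
    kept.each { |vec| vec.each_with_index { |x, i| sums[i] += x } }
    count = [kept.size, 1].max # avoid dividing by zero for an all-padding row
    sums.map { |s| s / count }
  end
end

def cls_pool(token_embeddings)
  # Keep the first token's vector of each sequence
  token_embeddings.map(&:first)
end

def l2_normalize(vectors)
  vectors.map do |v|
    norm = Math.sqrt(v.sum { |x| x * x })
    norm.zero? ? v : v.map { |x| x / norm }
  end
end

batch = [[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]] # one sequence; last token is padding
mask = [[1, 1, 0]]
p mean_pool(batch, mask)     # [[2.0, 3.0]]
p cls_pool(batch)            # [[1.0, 2.0]]
p l2_normalize([[3.0, 4.0]]) # [[0.6, 0.8]]
```

Masking matters for mean pooling: without it, the padding vector `[9.0, 9.0]` would pull the average to `[4.33, 5.0]`.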
@@ -0,0 +1,294 @@
+ module Informers
+   MODEL_TYPES = {
+     EncoderOnly: 0,
+     EncoderDecoder: 1,
+     Seq2Seq: 2,
+     Vision2Seq: 3,
+     DecoderOnly: 4,
+     MaskGeneration: 5
+   }
+
+   # NOTE: These will be populated fully later
+   MODEL_TYPE_MAPPING = {}
+   MODEL_NAME_TO_CLASS_MAPPING = {}
+   MODEL_CLASS_TO_NAME_MAPPING = {}
+
+   class PretrainedMixin
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       quantized: true,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       model_file_name: nil
+     )
+       options = {
+         quantized:,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:,
+         model_file_name:
+       }
+       config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)
+       if options[:config].nil?
+         # If no config was passed, reuse this config for future processing
+         options[:config] = config
+       end
+
+       if !const_defined?(:MODEL_CLASS_MAPPINGS)
+         raise Error, "`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}"
+       end
+
+       const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
+         model_info = model_class_mapping[config.model_type]
+         if !model_info
+           next # Item not found in this mapping
+         end
+         return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)
+       end
+
+       if const_defined?(:BASE_IF_FAIL)
+         warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+         PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
+       else
+         raise Error, "Unsupported model type: #{config.model_type}"
+       end
+     end
+   end
+
+   class PreTrainedModel
+     attr_reader :config
+
+     def initialize(config, session)
+       super()
+
+       @config = config
+       @session = session
+
+       @output_names = nil
+
+       model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+       model_type = MODEL_TYPE_MAPPING[model_name]
+
+       case model_type
+       when MODEL_TYPES[:DecoderOnly]
+         raise Todo
+       when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
+         raise Todo
+       when MODEL_TYPES[:EncoderDecoder]
+         raise Todo
+       else
+         @forward = method(:encoder_forward)
+       end
+     end
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       quantized: true,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       model_file_name: nil
+     )
+       options = {
+         quantized:,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:,
+         model_file_name:
+       }
+
+       model_name = MODEL_CLASS_TO_NAME_MAPPING[self]
+       model_type = MODEL_TYPE_MAPPING[model_name]
+
+       if model_type == MODEL_TYPES[:DecoderOnly]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:MaskGeneration]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:EncoderDecoder]
+         raise Todo
+
+       else
+         if model_type != MODEL_TYPES[:EncoderOnly]
+           warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+         end
+         info = [
+           AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+           construct_session(pretrained_model_name_or_path, options[:model_file_name] || "model", **options)
+         ]
+       end
+
+       new(*info)
+     end
+
+     def self.construct_session(pretrained_model_name_or_path, file_name, **options)
+       model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+       path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
+
+       OnnxRuntime::InferenceSession.new(path)
+     end
+
+     def call(model_inputs)
+       @forward.(model_inputs)
+     end
+
+     private
+
+     def encoder_forward(model_inputs)
+       encoder_feeds = {}
+       @session.inputs.each do |input|
+         key = input[:name].to_sym
+         encoder_feeds[key] = model_inputs[key]
+       end
+       if @session.inputs.any? { |v| v[:name] == "token_type_ids" } && !encoder_feeds[:token_type_ids]
+         raise Todo
+       end
+       session_run(@session, encoder_feeds)
+     end
+
+     def session_run(session, inputs)
+       checked_inputs = validate_inputs(session, inputs)
+       begin
+         output = session.run(@output_names, checked_inputs)
+         output = replace_tensors(output)
+         output
+       rescue => e
+         raise e
+       end
+     end
+
+     # TODO
+     def replace_tensors(obj)
+       obj
+     end
+
+     # TODO
+     def validate_inputs(session, inputs)
+       inputs
+     end
+   end
+
+   class BertPreTrainedModel < PreTrainedModel
+   end
+
+   class BertModel < BertPreTrainedModel
+   end
+
+   class BertForSequenceClassification < BertPreTrainedModel
+     def call(model_inputs)
+       SequenceClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class BertForTokenClassification < BertPreTrainedModel
+     def call(model_inputs)
+       TokenClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class DistilBertPreTrainedModel < PreTrainedModel
+   end
+
+   class DistilBertModel < DistilBertPreTrainedModel
+   end
+
+   class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+     def call(model_inputs)
+       SequenceClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+     def call(model_inputs)
+       QuestionAnsweringModelOutput.new(*super(model_inputs))
+     end
+   end
+
+   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
+     "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+   }
+
+   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => ["BertForTokenClassification", BertForTokenClassification]
+   }
+
+   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
+   }
+
+   MODEL_CLASS_TYPE_MAPPING = [
+     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+     [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+   ]
+
+   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
+     mappings.values.each do |name, model|
+       MODEL_TYPE_MAPPING[name] = type
+       MODEL_CLASS_TO_NAME_MAPPING[model] = name
+       MODEL_NAME_TO_CLASS_MAPPING[name] = model
+     end
+   end
+
+   class AutoModel < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }
+     BASE_IF_FAIL = true
+   end
+
+   class AutoModelForSequenceClassification < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]
+   end
+
+   class AutoModelForTokenClassification < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
+   end
+
+   class AutoModelForQuestionAnswering < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
+   end
+
+   class ModelOutput
+   end
+
+   class SequenceClassifierOutput < ModelOutput
+     attr_reader :logits
+
+     def initialize(logits)
+       super()
+       @logits = logits
+     end
+   end
+
+   class TokenClassifierOutput < ModelOutput
+     attr_reader :logits
+
+     def initialize(logits)
+       super()
+       @logits = logits
+     end
+   end
+
+   class QuestionAnsweringModelOutput < ModelOutput
+     attr_reader :start_logits, :end_logits
+
+     def initialize(start_logits, end_logits)
+       super()
+       @start_logits = start_logits
+       @end_logits = end_logits
+     end
+   end
+ end
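`PretrainedMixin.from_pretrained` walks `MODEL_CLASS_MAPPINGS` in order and returns the first class whose mapping contains `config.model_type`. Because `AutoModel` uses all three mappings in the order of `MODEL_CLASS_TYPE_MAPPING`, a `"bert"` config resolves to `BertForSequenceClassification`, the first match, rather than `BertModel`, which appears in no mapping in this version. The sketch below reproduces only that lookup, using stand-in classes inside a hypothetical `Sketch` module. A `nil` result marks the case where the real code either falls back to `PreTrainedModel` (when `BASE_IF_FAIL` is defined) or raises "Unsupported model type".

```ruby
module Sketch
  # Stand-in classes; only their identity matters here
  class BertForSequenceClassification; end
  class BertForTokenClassification; end
  class DistilBertForSequenceClassification; end
  class DistilBertForQuestionAnswering; end

  SEQUENCE = {
    "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
    "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
  }
  TOKEN = {"bert" => ["BertForTokenClassification", BertForTokenClassification]}
  QUESTION = {"distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]}

  # Same order as MODEL_CLASS_TYPE_MAPPING, which AutoModel iterates
  ALL = [SEQUENCE, TOKEN, QUESTION]

  # First mapping that knows the model_type wins; nil when none does
  def self.resolve(mappings, model_type)
    mappings.each do |mapping|
      info = mapping[model_type]
      return info[1] if info
    end
    nil
  end
end

p Sketch.resolve(Sketch::ALL, "bert")        # Sketch::BertForSequenceClassification
p Sketch.resolve([Sketch::TOKEN], "bert")    # Sketch::BertForTokenClassification
p Sketch.resolve([Sketch::QUESTION], "bert") # nil
```

The task-specific `AutoModelFor...` classes avoid the ordering question entirely because each checks a single mapping.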