informers 0.2.0 → 1.0.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 22f7bcebf0670078b65fdf9cba4d2b937c853a3b10cf36e47f50781e2663225c
- data.tar.gz: 940c96ec6b749b7e0b0c283456e40bfe9e6cbb3a58e8fa11f6367e87b05d8694
+ metadata.gz: 37ea3d1f5f6e4988731e3c3dd5854ede2fb0211a5dbde18fe70d09a713b12a1c
+ data.tar.gz: ac7b05dc9364e1984d35ccbfc2b7604d8ec9dc76f0f8c1a33f21ba489deed8f4
  SHA512:
- metadata.gz: 4cd8b58aae6e885409e297bc1ba09aedd029bb3dc26a193251f33c2bf6c9f6a8da69cb3727f799296a8c6644b014afc715e783a1e19a1074982af531e40db57b
- data.tar.gz: 6f63489d0b303e9a7de13df11d5074bd4cb2dfa44febee4061262d5c188eeb62a7c975e89567048f801fa183c8d56925275768fccc9a4b5a48255abeeb379345
+ metadata.gz: dcd02d4ff94ed472713de26e781cfbf963136eb07da1a9a195c4482c585e1b8ab19875583118f33669b10005bf08f607c09af040b3f53bbed896fb6d19fcf9e4
+ data.tar.gz: 990ea77bf9fdf859354d5532d0a1acefec6576b1d322efb41b27a43aa06b1f0fa2dea81825d0ca3631969bfd0aaf1323091defddc7ed951557370095ab7d209b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## 1.0.0 (2024-08-26)
+
+ - Replaced task classes with `pipeline` method
+ - Added `Model` class
+ - Dropped support for Ruby < 3.1
+
  ## 0.2.0 (2022-09-06)

  - Added support for `optimum` and `transformers.onnx` models
data/README.md CHANGED
@@ -1,15 +1,10 @@
  # Informers

- :slightly_smiling_face: State-of-the-art natural language processing for Ruby
+ :fire: Fast [transformer](https://github.com/xenova/transformers.js) inference for Ruby

- Supports:
+ For non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby)

- - Sentiment analysis
- - Question answering
- - Named-entity recognition
- - Text generation
-
- [![Build Status](https://github.com/ankane/informers/workflows/build/badge.svg?branch=master)](https://github.com/ankane/informers/actions)
+ [![Build Status](https://github.com/ankane/informers/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/informers/actions)

  ## Installation

@@ -21,140 +16,111 @@ gem "informers"

  ## Getting Started

- - [Sentiment analysis](#sentiment-analysis)
- - [Question answering](#question-answering)
- - [Named-entity recognition](#named-entity-recognition)
- - [Text generation](#text-generation)
- - [Feature extraction](#feature-extraction)
- - [Fill mask](#fill-mask)
+ - [Models](#models)
+ - [Pipelines](#pipelines)

- ### Sentiment Analysis
+ ## Models

- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/sentiment-analysis.onnx).
+ ### sentence-transformers/all-MiniLM-L6-v2

- Predict sentiment
+ [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

  ```ruby
- model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
- model.predict("This is super cool")
- ```
+ sentences = ["This is an example sentence", "Each sentence is converted"]

- This returns
-
- ```ruby
- {label: "positive", score: 0.999855186578301}
+ model = Informers::Model.new("sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = model.embed(sentences)
  ```

- Predict multiple at once
+ For a quantized version, use:

  ```ruby
- model.predict(["This is super cool", "I didn't like it"])
+ model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)
  ```

- ### Question Answering
-
- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
-
- Ask a question with some context
+ ### Xenova/multi-qa-MiniLM-L6-cos-v1

- ```ruby
- model = Informers::QuestionAnswering.new("question-answering.onnx")
- model.predict(
-   question: "Who invented Ruby?",
-   context: "Ruby is a programming language created by Matz"
- )
- ```
-
- This returns
+ [Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)

  ```ruby
- {answer: "Matz", score: 0.9980658360049758, start: 42, end: 46}
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Informers::Model.new("Xenova/multi-qa-MiniLM-L6-cos-v1")
+ query_embedding = model.embed(query)
+ doc_embeddings = model.embed(docs)
+ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
+ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```

- ### Named-Entity Recognition
+ ### mixedbread-ai/mxbai-embed-large-v1

- First, export the [pretrained model](tools/export.md).
-
- Get entities
+ [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

  ```ruby
- model = Informers::NER.new("ner.onnx")
- model.predict("Nat works at GitHub in San Francisco")
- ```
-
- This returns
-
- ```ruby
- [
-   {text: "Nat", tag: "person", score: 0.9840519576513487, start: 0, end: 3},
-   {text: "GitHub", tag: "org", score: 0.9426134775785775, start: 13, end: 19},
-   {text: "San Francisco", tag: "location", score: 0.9952414982243061, start: 23, end: 36}
+ def transform_query(query)
+   "Represent this sentence for searching relevant passages: #{query}"
+ end
+
+ docs = [
+   transform_query("puppy"),
+   "The dog is barking",
+   "The cat is purring"
  ]
- ```

- ### Text Generation
+ model = Informers::Model.new("mixedbread-ai/mxbai-embed-large-v1")
+ embeddings = model.embed(docs)
+ ```

- First, export the [pretrained model](tools/export.md).
+ ## Pipelines

- Pass a prompt
+ Named-entity recognition

  ```ruby
- model = Informers::TextGeneration.new("text-generation.onnx")
- model.predict("As far as I am concerned, I will", max_length: 50)
+ ner = Informers.pipeline("ner")
+ ner.("Ruby is a programming language created by Matz")
  ```

- This returns
+ Sentiment analysis

- ```text
- As far as I am concerned, I will be the first to admit that I am not a fan of the idea of a "free market." I think that the idea of a free market is a bit of a stretch. I think that the idea
+ ```ruby
+ classifier = Informers.pipeline("sentiment-analysis")
+ classifier.("We are very happy to show you the 🤗 Transformers library.")
  ```

- ### Feature Extraction
-
- First, export a [pretrained model](tools/export.md).
+ Question answering

  ```ruby
- model = Informers::FeatureExtraction.new("feature-extraction.onnx")
- model.predict("This is super cool")
+ qa = Informers.pipeline("question-answering")
+ qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
  ```

- ### Fill Mask
-
- First, export a [pretrained model](tools/export.md).
+ Feature extraction

  ```ruby
- model = Informers::FillMask.new("fill-mask.onnx")
- model.predict("This is a great <mask>")
+ extractor = Informers.pipeline("feature-extraction")
+ extractor.("We are very happy to show you the 🤗 Transformers library.")
  ```

- ## Models
-
- Task | Description | Contributor | License | Link
- --- | --- | --- | --- | ---
- Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
- Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
- Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
- Text generation | GPT-2 | OpenAI | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
-
- Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
-
- ## Deployment
+ ## Credits

- Check out [Trove](https://github.com/ankane/trove) for deploying models.
+ This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.

- ```sh
- trove push sentiment-analysis.onnx
- ```
+ ## Upgrading

- ## Credits
+ ### 1.0

- This project uses many state-of-the-art technologies:
+ Task classes have been replaced with the `pipeline` method.

- - [Transformers](https://github.com/huggingface/transformers) for transformer models
- - [Bling Fire](https://github.com/microsoft/BlingFire) and [BERT](https://github.com/google-research/bert) for high-performance text tokenization
- - [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high-performance inference
+ ```ruby
+ # before
+ model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
+ model.predict("This is super cool")

- Some code was ported from Transformers and is available under the same license.
+ # after
+ model = Informers.pipeline("sentiment-analysis")
+ model.("This is super cool")
+ ```

  ## History

@@ -175,7 +141,5 @@ To get started with development:
  git clone https://github.com/ankane/informers.git
  cd informers
  bundle install
-
- export MODELS_PATH=path/to/onnx/models
  bundle exec rake test
  ```
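Editor's note: the `multi-qa-MiniLM` example above scores documents with a plain dot product, which works because the pipeline is called with `normalize: true`, so the embeddings are unit length. For embeddings that are not normalized (the mxbai example uses `pooling: "cls"` with no normalization), cosine similarity must divide by the norms. A minimal sketch; the helper name is illustrative and not part of the gem:

```ruby
# Cosine similarity between two embeddings that may not be unit length.
# For normalized embeddings this reduces to the dot product used above.
def cosine_similarity(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  dot / (Math.sqrt(a.sum { |x| x * x }) * Math.sqrt(b.sum { |x| x * x }))
end
```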
@@ -0,0 +1,48 @@
+ module Informers
+   class PretrainedConfig
+     attr_reader :model_type, :problem_type, :id2label
+
+     def initialize(config_json)
+       @is_encoder_decoder = false
+
+       @model_type = config_json["model_type"]
+       @problem_type = config_json["problem_type"]
+       @id2label = config_json["id2label"]
+     end
+
+     def [](key)
+       instance_variable_get("@#{key}")
+     end
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       **kwargs
+     )
+       data = config || load_config(
+         pretrained_model_name_or_path,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:
+       )
+       new(data)
+     end
+
+     def self.load_config(pretrained_model_name_or_path, **options)
+       info = Utils::Hub.get_model_json(pretrained_model_name_or_path, "config.json", true, **options)
+       info
+     end
+   end
+
+   class AutoConfig
+     def self.from_pretrained(...)
+       PretrainedConfig.from_pretrained(...)
+     end
+   end
+ end
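Editor's note: the hunk above adds the config classes. A hedged usage sketch, assuming the model id below (illustrative) has a `config.json` reachable through the hub helpers:

```ruby
# Fetch and parse config.json, then read the exposed fields.
config = Informers::AutoConfig.from_pretrained("Xenova/all-MiniLM-L6-v2")
config.model_type      # => "bert" for MiniLM-style checkpoints
config["problem_type"] # the [] reader returns any instance variable by key
```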
@@ -0,0 +1,14 @@
+ module Informers
+   CACHE_HOME = ENV.fetch("XDG_CACHE_HOME", File.join(ENV.fetch("HOME"), ".cache"))
+   DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, "informers"))
+
+   class << self
+     attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir
+   end
+
+   self.allow_remote_models = ENV["INFORMERS_OFFLINE"].to_s.empty?
+   self.remote_host = "https://huggingface.co/"
+   self.remote_path_template = "{model}/resolve/{revision}/"
+
+   self.cache_dir = DEFAULT_CACHE_DIR
+ end
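Editor's note: this hunk defines module-level settings as plain accessors. A sketch of overriding them at application boot (the path is illustrative):

```ruby
# Point the cache somewhere else and disable remote downloads.
Informers.cache_dir = File.expand_path("tmp/informers")
Informers.allow_remote_models = false

# Note: allow_remote_models defaults from INFORMERS_OFFLINE at load time,
# so that environment variable must be set before the gem is required.
```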
@@ -0,0 +1,31 @@
+ module Informers
+   class Model
+     def initialize(model_id, quantized: false)
+       @model_id = model_id
+       @model = Informers.pipeline("feature-extraction", model_id, quantized: quantized)
+
+       # TODO better pattern
+       if model_id == "sentence-transformers/all-MiniLM-L6-v2"
+         @model.instance_variable_get(:@model).instance_variable_set(:@output_names, ["sentence_embedding"])
+       end
+     end
+
+     def embed(texts)
+       is_batched = texts.is_a?(Array)
+       texts = [texts] unless is_batched
+
+       case @model_id
+       when "sentence-transformers/all-MiniLM-L6-v2"
+         output = @model.(texts)
+       when "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1"
+         output = @model.(texts, pooling: "mean", normalize: true)
+       when "mixedbread-ai/mxbai-embed-large-v1"
+         output = @model.(texts, pooling: "cls")
+       else
+         raise Error, "model not supported: #{@model_id}"
+       end
+
+       is_batched ? output : output[0]
+     end
+   end
+ end
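Editor's note: `embed` mirrors the shape of its input via the `is_batched` check, which the README examples rely on, and only the four model ids handled by the `case` statement are supported; anything else raises `Informers::Error`. A short sketch:

```ruby
model = Informers::Model.new("Xenova/all-MiniLM-L6-v2")

model.embed("hello")            # single String in, one embedding out
model.embed(["hello", "world"]) # Array in, one embedding per element
```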
@@ -0,0 +1,294 @@
+ module Informers
+   MODEL_TYPES = {
+     EncoderOnly: 0,
+     EncoderDecoder: 1,
+     Seq2Seq: 2,
+     Vision2Seq: 3,
+     DecoderOnly: 4,
+     MaskGeneration: 5
+   }
+
+   # NOTE: These will be populated fully later
+   MODEL_TYPE_MAPPING = {}
+   MODEL_NAME_TO_CLASS_MAPPING = {}
+   MODEL_CLASS_TO_NAME_MAPPING = {}
+
+   class PretrainedMixin
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       quantized: true,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       model_file_name: nil
+     )
+       options = {
+         quantized:,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:,
+         model_file_name:
+       }
+       config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)
+       if options[:config].nil?
+         # If no config was passed, reuse this config for future processing
+         options[:config] = config
+       end
+
+       if !const_defined?(:MODEL_CLASS_MAPPINGS)
+         raise Error, "`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}"
+       end
+
+       const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
+         model_info = model_class_mapping[config.model_type]
+         if !model_info
+           next # Item not found in this mapping
+         end
+         return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)
+       end
+
+       if const_defined?(:BASE_IF_FAIL)
+         warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+         PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
+       else
+         raise Error, "Unsupported model type: #{config.model_type}"
+       end
+     end
+   end
+
+   class PreTrainedModel
+     attr_reader :config
+
+     def initialize(config, session)
+       super()
+
+       @config = config
+       @session = session
+
+       @output_names = nil
+
+       model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+       model_type = MODEL_TYPE_MAPPING[model_name]
+
+       case model_type
+       when MODEL_TYPES[:DecoderOnly]
+         raise Todo
+       when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
+         raise Todo
+       when MODEL_TYPES[:EncoderDecoder]
+         raise Todo
+       else
+         @forward = method(:encoder_forward)
+       end
+     end
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       quantized: true,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       model_file_name: nil
+     )
+       options = {
+         quantized:,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:,
+         model_file_name:
+       }
+
+       model_name = MODEL_CLASS_TO_NAME_MAPPING[self]
+       model_type = MODEL_TYPE_MAPPING[model_name]
+
+       if model_type == MODEL_TYPES[:DecoderOnly]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:MaskGeneration]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:EncoderDecoder]
+         raise Todo
+
+       else
+         if model_type != MODEL_TYPES[:EncoderOnly]
+           warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+         end
+         info = [
+           AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+           construct_session(pretrained_model_name_or_path, options[:model_file_name] || "model", **options)
+         ]
+       end
+
+       new(*info)
+     end
+
+     def self.construct_session(pretrained_model_name_or_path, file_name, **options)
+       model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+       path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
+
+       OnnxRuntime::InferenceSession.new(path)
+     end
+
+     def call(model_inputs)
+       @forward.(model_inputs)
+     end
+
+     private
+
+     def encoder_forward(model_inputs)
+       encoder_feeds = {}
+       @session.inputs.each do |input|
+         key = input[:name].to_sym
+         encoder_feeds[key] = model_inputs[key]
+       end
+       if @session.inputs.any? { |v| v[:name] == "token_type_ids" } && !encoder_feeds[:token_type_ids]
+         raise Todo
+       end
+       session_run(@session, encoder_feeds)
+     end
+
+     def session_run(session, inputs)
+       checked_inputs = validate_inputs(session, inputs)
+       begin
+         output = session.run(@output_names, checked_inputs)
+         output = replace_tensors(output)
+         output
+       rescue => e
+         raise e
+       end
+     end
+
+     # TODO
+     def replace_tensors(obj)
+       obj
+     end
+
+     # TODO
+     def validate_inputs(session, inputs)
+       inputs
+     end
+   end
+
+   class BertPreTrainedModel < PreTrainedModel
+   end
+
+   class BertModel < BertPreTrainedModel
+   end
+
+   class BertForSequenceClassification < BertPreTrainedModel
+     def call(model_inputs)
+       SequenceClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class BertForTokenClassification < BertPreTrainedModel
+     def call(model_inputs)
+       TokenClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class DistilBertPreTrainedModel < PreTrainedModel
+   end
+
+   class DistilBertModel < DistilBertPreTrainedModel
+   end
+
+   class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+     def call(model_inputs)
+       SequenceClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+     def call(model_inputs)
+       QuestionAnsweringModelOutput.new(*super(model_inputs))
+     end
+   end
+
+   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
+     "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+   }
+
+   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => ["BertForTokenClassification", BertForTokenClassification]
+   }
+
+   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
+   }
+
+   MODEL_CLASS_TYPE_MAPPING = [
+     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+     [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+   ]
+
+   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
+     mappings.values.each do |name, model|
+       MODEL_TYPE_MAPPING[name] = type
+       MODEL_CLASS_TO_NAME_MAPPING[model] = name
+       MODEL_NAME_TO_CLASS_MAPPING[name] = model
+     end
+   end
+
+   class AutoModel < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }
+     BASE_IF_FAIL = true
+   end
+
+   class AutoModelForSequenceClassification < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]
+   end
+
+   class AutoModelForTokenClassification < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
+   end
+
+   class AutoModelForQuestionAnswering < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
+   end
+
+   class ModelOutput
+   end
+
+   class SequenceClassifierOutput < ModelOutput
+     attr_reader :logits
+
+     def initialize(logits)
+       super()
+       @logits = logits
+     end
+   end
+
+   class TokenClassifierOutput < ModelOutput
+     attr_reader :logits
+
+     def initialize(logits)
+       super()
+       @logits = logits
+     end
+   end
+
+   class QuestionAnsweringModelOutput < ModelOutput
+     attr_reader :start_logits, :end_logits
+
+     def initialize(start_logits, end_logits)
+       super()
+       @start_logits = start_logits
+       @end_logits = end_logits
+     end
+   end
+ end
+ end