informers 0.2.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 22f7bcebf0670078b65fdf9cba4d2b937c853a3b10cf36e47f50781e2663225c
-   data.tar.gz: 940c96ec6b749b7e0b0c283456e40bfe9e6cbb3a58e8fa11f6367e87b05d8694
+   metadata.gz: 3abc738d8975839b873bc5e07bb95305d455a9ac1eec94c432415b713411f20b
+   data.tar.gz: b9c36794c33316378752dd816fb517714c6d8186062562a778d3c8539ba7d79a
  SHA512:
-   metadata.gz: 4cd8b58aae6e885409e297bc1ba09aedd029bb3dc26a193251f33c2bf6c9f6a8da69cb3727f799296a8c6644b014afc715e783a1e19a1074982af531e40db57b
-   data.tar.gz: 6f63489d0b303e9a7de13df11d5074bd4cb2dfa44febee4061262d5c188eeb62a7c975e89567048f801fa183c8d56925275768fccc9a4b5a48255abeeb379345
+   metadata.gz: ce05bfcdebce333fd6b5abefca703850d3a6d6a50c3c1589bf675e91ae24b424f2e43e6bc0270ad4ea8a520f5be9d636c5e8a5a66deae2c0183adae6cbc517aa
+   data.tar.gz: 6cc9b08b6e0f9e8ea23f306c0c460dc2557e4ee5113ef26300b517608485ea528fcb9254d51f395c37b557bf1728051c2c3dd8a20a25b5bd4826832a4ff30bf8
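The digests above cover the `metadata.gz` and `data.tar.gz` entries inside the `.gem` archive (a plain tar file), not the `.gem` file itself. A minimal verification sketch in Ruby, assuming the gem has been unpacked first with `tar -xf informers-1.0.1.gem`:

```ruby
require "digest"

# Expected digest for data.tar.gz, from the new checksums.yaml above
expected = "b9c36794c33316378752dd816fb517714c6d8186062562a778d3c8539ba7d79a"

# Compare against the SHA256 of the extracted archive entry
actual = Digest::SHA256.file("data.tar.gz").hexdigest
abort "checksum mismatch" unless actual == expected
```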
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
+ ## 1.0.1 (2024-08-27)
+
+ - Added support for `Supabase/gte-small` to `Model`
+ - Fixed error with downloads
+
+ ## 1.0.0 (2024-08-26)
+
+ - Replaced task classes with `pipeline` method
+ - Added `Model` class
+ - Dropped support for Ruby < 3.1
+
  ## 0.2.0 (2022-09-06)

  - Added support for `optimum` and `transformers.onnx` models
data/README.md CHANGED
@@ -1,15 +1,10 @@
  # Informers

- :slightly_smiling_face: State-of-the-art natural language processing for Ruby
+ :fire: Fast [transformer](https://github.com/xenova/transformers.js) inference for Ruby

- Supports:
+ For non-ONNX models, check out [Transformers.rb](https://github.com/ankane/transformers-ruby) :slightly_smiling_face:

- - Sentiment analysis
- - Question answering
- - Named-entity recognition
- - Text generation
-
- [![Build Status](https://github.com/ankane/informers/workflows/build/badge.svg?branch=master)](https://github.com/ankane/informers/actions)
+ [![Build Status](https://github.com/ankane/informers/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/informers/actions)

  ## Installation

@@ -21,140 +16,122 @@ gem "informers"

  ## Getting Started

- - [Sentiment analysis](#sentiment-analysis)
- - [Question answering](#question-answering)
- - [Named-entity recognition](#named-entity-recognition)
- - [Text generation](#text-generation)
- - [Feature extraction](#feature-extraction)
- - [Fill mask](#fill-mask)
+ - [Models](#models)
+ - [Pipelines](#pipelines)

- ### Sentiment Analysis
+ ## Models

- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/sentiment-analysis.onnx).
+ ### sentence-transformers/all-MiniLM-L6-v2

- Predict sentiment
+ [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

  ```ruby
- model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
- model.predict("This is super cool")
- ```
+ sentences = ["This is an example sentence", "Each sentence is converted"]

- This returns
-
- ```ruby
- {label: "positive", score: 0.999855186578301}
+ model = Informers::Model.new("sentence-transformers/all-MiniLM-L6-v2")
+ embeddings = model.embed(sentences)
  ```

- Predict multiple at once
+ For a quantized version, use:

  ```ruby
- model.predict(["This is super cool", "I didn't like it"])
+ model = Informers::Model.new("Xenova/all-MiniLM-L6-v2", quantized: true)
  ```

- ### Question Answering
+ ### Xenova/multi-qa-MiniLM-L6-cos-v1

- First, download the [pretrained model](https://github.com/ankane/informers/releases/download/v0.1.0/question-answering.onnx).
-
- Ask a question with some context
+ [Docs](https://huggingface.co/Xenova/multi-qa-MiniLM-L6-cos-v1)

  ```ruby
- model = Informers::QuestionAnswering.new("question-answering.onnx")
- model.predict(
-   question: "Who invented Ruby?",
-   context: "Ruby is a programming language created by Matz"
- )
+ query = "How many people live in London?"
+ docs = ["Around 9 Million people live in London", "London is known for its financial district"]
+
+ model = Informers::Model.new("Xenova/multi-qa-MiniLM-L6-cos-v1")
+ query_embedding = model.embed(query)
+ doc_embeddings = model.embed(docs)
+ scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
+ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
  ```

- This returns
+ ### mixedbread-ai/mxbai-embed-large-v1
+
+ [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)

  ```ruby
- {answer: "Matz", score: 0.9980658360049758, start: 42, end: 46}
- ```
+ def transform_query(query)
+   "Represent this sentence for searching relevant passages: #{query}"
+ end
+
+ docs = [
+   transform_query("puppy"),
+   "The dog is barking",
+   "The cat is purring"
+ ]

- ### Named-Entity Recognition
+ model = Informers::Model.new("mixedbread-ai/mxbai-embed-large-v1")
+ embeddings = model.embed(docs)
+ ```

- First, export the [pretrained model](tools/export.md).
+ ### Supabase/gte-small

- Get entities
+ [Docs](https://huggingface.co/Supabase/gte-small)

  ```ruby
- model = Informers::NER.new("ner.onnx")
- model.predict("Nat works at GitHub in San Francisco")
- ```
+ sentences = ["That is a happy person", "That is a very happy person"]

- This returns
-
- ```ruby
- [
-   {text: "Nat", tag: "person", score: 0.9840519576513487, start: 0, end: 3},
-   {text: "GitHub", tag: "org", score: 0.9426134775785775, start: 13, end: 19},
-   {text: "San Francisco", tag: "location", score: 0.9952414982243061, start: 23, end: 36}
- ]
+ model = Informers::Model.new("Supabase/gte-small")
+ embeddings = model.embed(sentences)
  ```

- ### Text Generation
-
- First, export the [pretrained model](tools/export.md).
+ ## Pipelines

- Pass a prompt
+ Named-entity recognition

  ```ruby
- model = Informers::TextGeneration.new("text-generation.onnx")
- model.predict("As far as I am concerned, I will", max_length: 50)
+ ner = Informers.pipeline("ner")
+ ner.("Ruby is a programming language created by Matz")
  ```

- This returns
+ Sentiment analysis

- ```text
- As far as I am concerned, I will be the first to admit that I am not a fan of the idea of a "free market." I think that the idea of a free market is a bit of a stretch. I think that the idea
+ ```ruby
+ classifier = Informers.pipeline("sentiment-analysis")
+ classifier.("We are very happy to show you the 🤗 Transformers library.")
  ```

- ### Feature Extraction
-
- First, export a [pretrained model](tools/export.md).
+ Question answering

  ```ruby
- model = Informers::FeatureExtraction.new("feature-extraction.onnx")
- model.predict("This is super cool")
+ qa = Informers.pipeline("question-answering")
+ qa.("Who invented Ruby?", "Ruby is a programming language created by Matz")
  ```

- ### Fill Mask
-
- First, export a [pretrained model](tools/export.md).
+ Feature extraction

  ```ruby
- model = Informers::FillMask.new("fill-mask.onnx")
- model.predict("This is a great <mask>")
+ extractor = Informers.pipeline("feature-extraction")
+ extractor.("We are very happy to show you the 🤗 Transformers library.")
  ```

- ## Models
-
- Task | Description | Contributor | License | Link
- --- | --- | --- | --- | ---
- Sentiment analysis | DistilBERT fine-tuned on SST-2 | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)
- Question answering | DistilBERT fine-tuned on SQuAD | Hugging Face | Apache-2.0 | [Link](https://huggingface.co/distilbert-base-cased-distilled-squad)
- Named-entity recognition | BERT fine-tuned on CoNLL03 | Bayerische Staatsbibliothek | In-progress | [Link](https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english)
- Text generation | GPT-2 | OpenAI | [Custom](https://github.com/openai/gpt-2/blob/master/LICENSE) | [Link](https://huggingface.co/gpt2)
+ ## Credits

- Some models are [quantized](https://medium.com/microsoftazure/faster-and-smaller-quantized-nlp-with-hugging-face-and-onnx-runtime-ec5525473bb7) to make them faster and smaller.
+ This library was ported from [Transformers.js](https://github.com/xenova/transformers.js) and is available under the same license.

- ## Deployment
+ ## Upgrading

- Check out [Trove](https://github.com/ankane/trove) for deploying models.
+ ### 1.0

- ```sh
- trove push sentiment-analysis.onnx
- ```
+ Task classes have been replaced with the `pipeline` method.

- ## Credits
-
- This project uses many state-of-the-art technologies:
-
- - [Transformers](https://github.com/huggingface/transformers) for transformer models
- - [Bling Fire](https://github.com/microsoft/BlingFire) and [BERT](https://github.com/google-research/bert) for high-performance text tokenization
- - [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high-performance inference
+ ```ruby
+ # before
+ model = Informers::SentimentAnalysis.new("sentiment-analysis.onnx")
+ model.predict("This is super cool")

- Some code was ported from Transformers and is available under the same license.
+ # after
+ model = Informers.pipeline("sentiment-analysis")
+ model.("This is super cool")
+ ```

  ## History

@@ -175,7 +152,5 @@ To get started with development:
  git clone https://github.com/ankane/informers.git
  cd informers
  bundle install
-
- export MODELS_PATH=path/to/onnx/models
  bundle exec rake test
  ```
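Since `Model#embed` returns plain Ruby arrays, similarity search needs no extra dependencies. A small sketch (not from the README): cosine similarity between two embeddings. For models that `Model#embed` runs with `pooling: "mean", normalize: true` (see the `Model` class later in this diff), the vectors are already unit length, so the dot product alone gives the same ranking.

```ruby
# Cosine similarity between two embedding vectors (arrays of Floats)
def cosine_similarity(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  dot / (Math.sqrt(a.sum { |x| x * x }) * Math.sqrt(b.sum { |x| x * x }))
end

model = Informers::Model.new("Supabase/gte-small")
a, b = model.embed(["That is a happy person", "That is a very happy person"])
cosine_similarity(a, b) # close to 1.0 for near-duplicate sentences
```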
@@ -0,0 +1,48 @@
+ module Informers
+   class PretrainedConfig
+     attr_reader :model_type, :problem_type, :id2label
+
+     def initialize(config_json)
+       @is_encoder_decoder = false
+
+       @model_type = config_json["model_type"]
+       @problem_type = config_json["problem_type"]
+       @id2label = config_json["id2label"]
+     end
+
+     def [](key)
+       instance_variable_get("@#{key}")
+     end
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       **kwargs
+     )
+       data = config || load_config(
+         pretrained_model_name_or_path,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:
+       )
+       new(data)
+     end
+
+     def self.load_config(pretrained_model_name_or_path, **options)
+       info = Utils::Hub.get_model_json(pretrained_model_name_or_path, "config.json", true, **options)
+       info
+     end
+   end
+
+   class AutoConfig
+     def self.from_pretrained(...)
+       PretrainedConfig.from_pretrained(...)
+     end
+   end
+ end
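A short usage sketch for the config classes in this new file (the model id is illustrative; `Utils::Hub.get_model_json` is defined elsewhere in the gem and fetches `config.json` from the Hugging Face Hub or the local cache):

```ruby
config = Informers::AutoConfig.from_pretrained("Xenova/all-MiniLM-L6-v2")
config.model_type # => "bert" for this model
config[:id2label] # hash-style access via the [] method above
```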
@@ -0,0 +1,14 @@
+ module Informers
+   CACHE_HOME = ENV.fetch("XDG_CACHE_HOME", File.join(ENV.fetch("HOME"), ".cache"))
+   DEFAULT_CACHE_DIR = File.expand_path(File.join(CACHE_HOME, "informers"))
+
+   class << self
+     attr_accessor :allow_remote_models, :remote_host, :remote_path_template, :cache_dir
+   end
+
+   self.allow_remote_models = ENV["INFORMERS_OFFLINE"].to_s.empty?
+   self.remote_host = "https://huggingface.co/"
+   self.remote_path_template = "{model}/resolve/{revision}/"
+
+   self.cache_dir = DEFAULT_CACHE_DIR
+ end
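These module-level settings can be overridden at runtime; a minimal sketch (paths illustrative):

```ruby
# Store downloaded model files somewhere other than ~/.cache/informers
Informers.cache_dir = "/tmp/informers"

# Disable downloads; same effect as setting INFORMERS_OFFLINE to a
# non-empty value before the gem is loaded
Informers.allow_remote_models = false
```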
@@ -0,0 +1,31 @@
+ module Informers
+   class Model
+     def initialize(model_id, quantized: false)
+       @model_id = model_id
+       @model = Informers.pipeline("feature-extraction", model_id, quantized: quantized)
+
+       # TODO better pattern
+       if model_id == "sentence-transformers/all-MiniLM-L6-v2"
+         @model.instance_variable_get(:@model).instance_variable_set(:@output_names, ["sentence_embedding"])
+       end
+     end
+
+     def embed(texts)
+       is_batched = texts.is_a?(Array)
+       texts = [texts] unless is_batched
+
+       case @model_id
+       when "sentence-transformers/all-MiniLM-L6-v2"
+         output = @model.(texts)
+       when "Xenova/all-MiniLM-L6-v2", "Xenova/multi-qa-MiniLM-L6-cos-v1", "Supabase/gte-small"
+         output = @model.(texts, pooling: "mean", normalize: true)
+       when "mixedbread-ai/mxbai-embed-large-v1"
+         output = @model.(texts, pooling: "cls")
+       else
+         raise Error, "model not supported: #{@model_id}"
+       end
+
+       is_batched ? output : output[0]
+     end
+   end
+ end
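Note the batching behavior above: `embed` accepts a single string or an array and mirrors the shape of its input, while model ids outside the `case` hit the `else` branch and raise. A short sketch:

```ruby
model = Informers::Model.new("Supabase/gte-small")
model.embed("hello")            # a single vector (Array of Floats)
model.embed(["hello", "world"]) # an array of two vectors
```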
@@ -0,0 +1,294 @@
+ module Informers
+   MODEL_TYPES = {
+     EncoderOnly: 0,
+     EncoderDecoder: 1,
+     Seq2Seq: 2,
+     Vision2Seq: 3,
+     DecoderOnly: 4,
+     MaskGeneration: 5
+   }
+
+   # NOTE: These will be populated fully later
+   MODEL_TYPE_MAPPING = {}
+   MODEL_NAME_TO_CLASS_MAPPING = {}
+   MODEL_CLASS_TO_NAME_MAPPING = {}
+
+   class PretrainedMixin
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       quantized: true,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       model_file_name: nil
+     )
+       options = {
+         quantized:,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:,
+         model_file_name:
+       }
+       config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **options)
+       if options[:config].nil?
+         # If no config was passed, reuse this config for future processing
+         options[:config] = config
+       end
+
+       if !const_defined?(:MODEL_CLASS_MAPPINGS)
+         raise Error, "`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: #{name}"
+       end
+
+       const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
+         model_info = model_class_mapping[config.model_type]
+         if !model_info
+           next # Item not found in this mapping
+         end
+         return model_info[1].from_pretrained(pretrained_model_name_or_path, **options)
+       end
+
+       if const_defined?(:BASE_IF_FAIL)
+         warn "Unknown model class #{config.model_type.inspect}, attempting to construct from base class."
+         PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
+       else
+         raise Error, "Unsupported model type: #{config.model_type}"
+       end
+     end
+   end
+
+   class PreTrainedModel
+     attr_reader :config
+
+     def initialize(config, session)
+       super()
+
+       @config = config
+       @session = session
+
+       @output_names = nil
+
+       model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+       model_type = MODEL_TYPE_MAPPING[model_name]
+
+       case model_type
+       when MODEL_TYPES[:DecoderOnly]
+         raise Todo
+       when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
+         raise Todo
+       when MODEL_TYPES[:EncoderDecoder]
+         raise Todo
+       else
+         @forward = method(:encoder_forward)
+       end
+     end
+
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       quantized: true,
+       progress_callback: nil,
+       config: nil,
+       cache_dir: nil,
+       local_files_only: false,
+       revision: "main",
+       model_file_name: nil
+     )
+       options = {
+         quantized:,
+         progress_callback:,
+         config:,
+         cache_dir:,
+         local_files_only:,
+         revision:,
+         model_file_name:
+       }
+
+       model_name = MODEL_CLASS_TO_NAME_MAPPING[self]
+       model_type = MODEL_TYPE_MAPPING[model_name]
+
+       if model_type == MODEL_TYPES[:DecoderOnly]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:MaskGeneration]
+         raise Todo
+
+       elsif model_type == MODEL_TYPES[:EncoderDecoder]
+         raise Todo
+
+       else
+         if model_type != MODEL_TYPES[:EncoderOnly]
+           warn "Model type for '#{model_name || config&.model_type}' not found, assuming encoder-only architecture. Please report this."
+         end
+         info = [
+           AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+           construct_session(pretrained_model_name_or_path, options[:model_file_name] || "model", **options)
+         ]
+       end
+
+       new(*info)
+     end
+
+     def self.construct_session(pretrained_model_name_or_path, file_name, **options)
+       model_file_name = "onnx/#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
+       path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
+
+       OnnxRuntime::InferenceSession.new(path)
+     end
+
+     def call(model_inputs)
+       @forward.(model_inputs)
+     end
+
+     private
+
+     def encoder_forward(model_inputs)
+       encoder_feeds = {}
+       @session.inputs.each do |input|
+         key = input[:name].to_sym
+         encoder_feeds[key] = model_inputs[key]
+       end
+       if @session.inputs.any? { |v| v[:name] == "token_type_ids" } && !encoder_feeds[:token_type_ids]
+         raise Todo
+       end
+       session_run(@session, encoder_feeds)
+     end
+
+     def session_run(session, inputs)
+       checked_inputs = validate_inputs(session, inputs)
+       begin
+         output = session.run(@output_names, checked_inputs)
+         output = replace_tensors(output)
+         output
+       rescue => e
+         raise e
+       end
+     end
+
+     # TODO
+     def replace_tensors(obj)
+       obj
+     end
+
+     # TODO
+     def validate_inputs(session, inputs)
+       inputs
+     end
+   end
+
+   class BertPreTrainedModel < PreTrainedModel
+   end
+
+   class BertModel < BertPreTrainedModel
+   end
+
+   class BertForSequenceClassification < BertPreTrainedModel
+     def call(model_inputs)
+       SequenceClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class BertForTokenClassification < BertPreTrainedModel
+     def call(model_inputs)
+       TokenClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class DistilBertPreTrainedModel < PreTrainedModel
+   end
+
+   class DistilBertModel < DistilBertPreTrainedModel
+   end
+
+   class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+     def call(model_inputs)
+       SequenceClassifierOutput.new(*super(model_inputs))
+     end
+   end
+
+   class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+     def call(model_inputs)
+       QuestionAnsweringModelOutput.new(*super(model_inputs))
+     end
+   end
+
+   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
+     "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification]
+   }
+
+   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
+     "bert" => ["BertForTokenClassification", BertForTokenClassification]
+   }
+
+   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
+     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
+   }
+
+   MODEL_CLASS_TYPE_MAPPING = [
+     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+     [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
+   ]
+
+   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
+     mappings.values.each do |name, model|
+       MODEL_TYPE_MAPPING[name] = type
+       MODEL_CLASS_TO_NAME_MAPPING[model] = name
+       MODEL_NAME_TO_CLASS_MAPPING[name] = model
+     end
+   end
+
+   class AutoModel < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map { |x| x[0] }
+     BASE_IF_FAIL = true
+   end
+
+   class AutoModelForSequenceClassification < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]
+   end
+
+   class AutoModelForTokenClassification < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
+   end
+
+   class AutoModelForQuestionAnswering < PretrainedMixin
+     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
+   end
+
+   class ModelOutput
+   end
+
+   class SequenceClassifierOutput < ModelOutput
+     attr_reader :logits
+
+     def initialize(logits)
+       super()
+       @logits = logits
+     end
+   end
+
+   class TokenClassifierOutput < ModelOutput
+     attr_reader :logits
+
+     def initialize(logits)
+       super()
+       @logits = logits
+     end
+   end
+
+   class QuestionAnsweringModelOutput < ModelOutput
+     attr_reader :start_logits, :end_logits
+
+     def initialize(start_logits, end_logits)
+       super()
+       @start_logits = start_logits
+       @end_logits = end_logits
+     end
+   end
+ end
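A hedged sketch of how the auto classes above resolve and run a concrete model. The model id and token IDs are illustrative, and a real call site would produce `input_ids`/`attention_mask` with a tokenizer, which lives outside this file; this assumes a repo that ships `onnx/model_quantized.onnx`, such as the Xenova mirrors.

```ruby
# config.json's model_type ("distilbert") selects DistilBertForSequenceClassification
model = Informers::AutoModelForSequenceClassification.from_pretrained(
  "Xenova/distilbert-base-uncased-finetuned-sst-2-english"
)

# encoder_forward feeds only the inputs the ONNX session declares
output = model.(input_ids: [[101, 2023, 2003, 102]], attention_mask: [[1, 1, 1, 1]])
output.logits
```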