transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/pipelines/_init.rb
@@ -0,0 +1,348 @@
```ruby
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  TASK_ALIASES = {
    "sentiment-analysis" => "text-classification",
    "ner" => "token-classification"
  }

  SUPPORTED_TASKS = {
    "feature-extraction" => {
      "impl" => FeatureExtractionPipeline,
      "pt" => [AutoModel],
      "default" => {
        "model" => {
          "pt" => ["distilbert/distilbert-base-cased", "935ac13"]
        }
      },
      "type" => "multimodal"
    },
    "text-classification" => {
      "impl" => TextClassificationPipeline,
      "pt" => [AutoModelForSequenceClassification],
      "default" => {
        "model" => {
          "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"]
        }
      },
      "type" => "text"
    },
    "token-classification" => {
      "impl" => TokenClassificationPipeline,
      "pt" => [AutoModelForTokenClassification],
      "default" => {
        "model" => {
          "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"]
        }
      },
      "type" => "text"
    },
    "question-answering" => {
      "impl" => QuestionAnsweringPipeline,
      "pt" => [AutoModelForQuestionAnswering],
      "default" => {
        "model" => {
          "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "626af31"]
        }
      },
      "type" => "text"
    },
    "image-classification" => {
      "impl" => ImageClassificationPipeline,
      "pt" => [AutoModelForImageClassification],
      "default" => {
        "model" => {
          "pt" => ["google/vit-base-patch16-224", "5dca96d"]
        }
      },
      "type" => "image"
    },
    "image-feature-extraction" => {
      "impl" => ImageFeatureExtractionPipeline,
      "pt" => [AutoModel],
      "default" => {
        "model" => {
          "pt" => ["google/vit-base-patch16-224", "3f49326"]
        }
      },
      "type" => "image"
    }
  }

  PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)

  class << self
    def pipeline(
      task,
      model: nil,
      config: nil,
      tokenizer: nil,
      feature_extractor: nil,
      image_processor: nil,
      framework: nil,
      revision: nil,
      use_fast: true,
      token: nil,
      device: nil,
      device_map: nil,
      torch_dtype: nil,
      trust_remote_code: nil,
      model_kwargs: nil,
      pipeline_class: nil,
      **kwargs
    )
      model_kwargs ||= {}
      # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
      # this is to keep BC).
      use_auth_token = model_kwargs.delete(:use_auth_token)
      if !use_auth_token.nil?
        raise Todo
      end

      code_revision = kwargs.delete(:code_revision)
      commit_hash = kwargs.delete(:_commit_hash)

      hub_kwargs = {
        revision: revision,
        token: token,
        trust_remote_code: trust_remote_code,
        _commit_hash: commit_hash
      }

      if task.nil? && model.nil?
        raise RuntimeError,
          "Impossible to instantiate a pipeline without either a task or a model " +
          "being specified. " +
          "Please provide a task class or a model"
      end

      if model.nil? && !tokenizer.nil?
        raise RuntimeError,
          "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" +
          " may not be compatible with the default model. Please provide a PreTrainedModel class or a" +
          " path/identifier to a pretrained model when providing tokenizer."
      end
      if model.nil? && !feature_extractor.nil?
        raise RuntimeError,
          "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided" +
          " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" +
          " or a path/identifier to a pretrained model when providing feature_extractor."
      end
      if model.is_a?(Pathname)
        model = model.to_s
      end

      if commit_hash.nil?
        pretrained_model_name_or_path = nil
        if config.is_a?(String)
          pretrained_model_name_or_path = config
        elsif config.nil? && model.is_a?(String)
          pretrained_model_name_or_path = model
        end

        if !config.is_a?(PretrainedConfig) && !pretrained_model_name_or_path.nil?
          # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
          resolved_config_file = Utils::Hub.cached_file(
            pretrained_model_name_or_path,
            CONFIG_NAME,
            _raise_exceptions_for_gated_repo: false,
            _raise_exceptions_for_missing_entries: false,
            _raise_exceptions_for_connection_errors: false,
            cache_dir: model_kwargs[:cache_dir],
            **hub_kwargs
          )
          hub_kwargs[:_commit_hash] = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
        else
          hub_kwargs[:_commit_hash] = nil # getattr(config, "_commit_hash", None)
        end
      end

      # Config is the primordial information item.
      # Instantiate config if needed
      if config.is_a?(String)
        raise Todo
      elsif config.nil? && model.is_a?(String)
        config = AutoConfig.from_pretrained(
          model, _from_pipeline: task, code_revision: code_revision, **hub_kwargs, **model_kwargs
        )
        hub_kwargs[:_commit_hash] = config._commit_hash
      end

      custom_tasks = {}
      if !config.nil? && (config.instance_variable_get(:@custom_pipelines) || {}).length > 0
        raise Todo
      end

      if task.nil? && !model.nil?
        raise Todo
      end

      # Retrieve the task
      if custom_tasks.include?(task)
        raise Todo
      else
        _normalized_task, targeted_task, task_options = check_task(task)
        if pipeline_class.nil?
          pipeline_class = targeted_task["impl"]
        end
      end

      # Use default model/config/tokenizer for the task if no model is provided
      if model.nil?
        # At that point framework might still be undetermined
        model, default_revision = Pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
        revision = !revision.nil? ? revision : default_revision
        Transformers.logger.warn(
          "No model was supplied, defaulted to #{model} and revision" +
          " #{revision} (#{Utils::Hub::HUGGINGFACE_CO_RESOLVE_ENDPOINT}/#{model}).\n" +
          "Using a pipeline without specifying a model name and revision in production is not recommended."
        )
        if config.nil? && model.is_a?(String)
          config = AutoConfig.from_pretrained(model, _from_pipeline: task, **hub_kwargs, **model_kwargs)
          hub_kwargs[:_commit_hash] = config._commit_hash
        end
      end

      if !device_map.nil?
        raise Todo
      end
      if !torch_dtype.nil?
        raise Todo
      end

      model_name = model.is_a?(String) ? model : nil

      # Load the correct model if possible
      # Infer the framework from the model if not already defined
      if model.is_a?(String) || framework.nil?
        model_classes = {"tf" => targeted_task["tf"], "pt" => targeted_task["pt"]}
        framework, model =
          Pipelines.infer_framework_load_model(
            model,
            config,
            model_classes: model_classes,
            framework: framework,
            task: task,
            **hub_kwargs,
            **model_kwargs
          )
      end

      model_config = model.config
      hub_kwargs[:_commit_hash] = model.config._commit_hash
      model_config_type = model_config.class.name.split("::").last
      load_tokenizer = TOKENIZER_MAPPING.include?(model_config_type) || !model_config.tokenizer_class.nil?
      load_feature_extractor = FEATURE_EXTRACTOR_MAPPING.include?(model_config_type) || !feature_extractor.nil?
      load_image_processor = IMAGE_PROCESSOR_MAPPING.include?(model_config_type) || !image_processor.nil?

      if load_tokenizer
        # Try to infer tokenizer from model or config name (if provided as str)
        if tokenizer.nil?
          if model_name.is_a?(String)
            tokenizer = model_name
          elsif config.is_a?(String)
            tokenizer = config
          else
            # Impossible to guess what is the right tokenizer here
            raise "Impossible to guess which tokenizer to use. Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
          end
        end

        # Instantiate tokenizer if needed
        if tokenizer.is_a?(String) || tokenizer.is_a?(Array)
          if tokenizer.is_a?(Array)
            # For array we have [tokenizer name, {kwargs}]
            use_fast = tokenizer[1].delete(:use_fast) { use_fast }
            tokenizer_identifier = tokenizer[0]
            tokenizer_kwargs = tokenizer[1]
          else
            tokenizer_identifier = tokenizer
            tokenizer_kwargs = model_kwargs.dup
            tokenizer_kwargs.delete(:torch_dtype)
          end

          tokenizer =
            AutoTokenizer.from_pretrained(
              tokenizer_identifier, use_fast: use_fast, _from_pipeline: task, **hub_kwargs, **tokenizer_kwargs
            )
        end
      end

      if load_image_processor
        # Try to infer image processor from model or config name (if provided as str)
        if image_processor.nil?
          if model_name.is_a?(String)
            image_processor = model_name
          elsif config.is_a?(String)
            image_processor = config
          # Backward compatibility, as `feature_extractor` used to be the name
          # for `ImageProcessor`.
          elsif !feature_extractor.nil? && feature_extractor.is_a?(BaseImageProcessor)
            image_processor = feature_extractor
          else
            # Impossible to guess what is the right image_processor here
            raise RuntimeError,
              "Impossible to guess which image processor to use. " +
              "Please provide a PreTrainedImageProcessor class or a path/identifier " +
              "to a pretrained image processor."
          end
        end

        # Instantiate image_processor if needed
        if image_processor.is_a?(String) || image_processor.is_a?(Array)
          image_processor = AutoImageProcessor.from_pretrained(
            image_processor, _from_pipeline: task, **hub_kwargs, **model_kwargs
          )
        end
      end

      if load_feature_extractor
        raise Todo
      end

      if task == "translation" && model.config.task_specific_params
        raise Todo
      end

      if !tokenizer.nil?
        kwargs[:tokenizer] = tokenizer
      end

      if !feature_extractor.nil?
        kwargs[:feature_extractor] = feature_extractor
      end

      if !torch_dtype.nil?
        kwargs[:torch_dtype] = torch_dtype
      end

      if !image_processor.nil?
        kwargs[:image_processor] = image_processor
      end

      if !device.nil?
        kwargs[:device] = device
      end

      pipeline_class.new(model, framework: framework, task: task, **kwargs)
    end

    private

    def check_task(task)
      PIPELINE_REGISTRY.check_task(task)
    end
  end
end
```
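For reference, a minimal usage sketch of the factory above (the input string and the output hash are illustrative, not taken from the gem's docs): `Transformers.pipeline` resolves the task alias, downloads the pinned default checkpoint from the Hugging Face Hub on first use, and returns a pipeline object that is invoked with `.()`.

```ruby
require "transformers-rb"

# "sentiment-analysis" is an alias (TASK_ALIASES) for "text-classification",
# so this resolves to TextClassificationPipeline with the pinned default
# model ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b").
classifier = Transformers.pipeline("sentiment-analysis")

# Pipeline#call is invoked via .(); a single string takes the run_single path.
classifier.("We are very happy to show you the Transformers library.")
# => e.g. {"label" => "POSITIVE", "score" => 0.9998}
```

Because each default pins both a model and a revision, the factory logs a warning when no model is supplied: pinned defaults can change between releases, so production code should name its model explicitly.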
data/lib/transformers/pipelines/base.rb
@@ -0,0 +1,301 @@
```ruby
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Pipelines
    def self.get_default_model_and_revision(targeted_task, framework, task_options)
      defaults = targeted_task["default"]

      if defaults.key?("model")
        default_models = targeted_task["default"]["model"]
      end

      if framework.nil?
        framework = "pt"
      end

      default_models[framework]
    end

    def self.infer_framework_load_model(
      model,
      config,
      model_classes: nil,
      task: nil,
      framework: nil,
      **model_kwargs
    )
      if model.is_a?(String)
        model_kwargs[:_from_pipeline] = task
        class_tuple = []
        look_pt = true

        if model_classes
          if look_pt
            class_tuple = class_tuple + model_classes.fetch("pt", AutoModel)
          end
        end
        if config.architectures
          classes = []
          config.architectures.each do |architecture|
            if look_pt
              _class = Transformers.const_get(architecture)
              if !_class.nil?
                classes << _class
              end
            end
          end
          class_tuple = class_tuple + classes
        end

        if class_tuple.length == 0
          raise ArgumentError, "Pipeline cannot infer suitable model classes from #{model}"
        end

        class_tuple.each do |model_class|
          raise Error, "Invalid auto model class: #{model_class}" unless model_class < BaseAutoModelClass
          kwargs = model_kwargs.dup

          begin
            model = model_class.from_pretrained(model, **kwargs)
            if model.respond_to?(:eval)
              model = model.eval
            end
            break
          rescue
            # TODO
            raise
          end
        end
      end

      if framework.nil?
        framework = Utils.infer_framework(model.class)
      end
      [framework, model]
    end
  end

  class ArgumentHandler
  end

  class Pipeline
    def initialize(
      model,
      tokenizer: nil,
      feature_extractor: nil,
      image_processor: nil,
      modelcard: nil,
      framework: nil,
      task: "",
      device: nil,
      **kwargs
    )
      if framework.nil?
        raise Todo
      end

      @task = task
      @model = model
      @tokenizer = tokenizer
      @feature_extractor = feature_extractor
      @image_processor = image_processor
      @modelcard = modelcard
      @framework = framework

      if device.nil?
        if Torch::CUDA.available? || Torch::Backends::MPS.available?
          Transformers.logger.warn(
            "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument" +
            " is passed to the `Pipeline` object. Model will be on CPU."
          )
        end
      end

      @call_count = 0
      @batch_size = kwargs.delete(:batch_size)
      @num_workers = kwargs.delete(:num_workers)
      @preprocess_params, @forward_params, @postprocess_params = _sanitize_parameters(**kwargs)
    end

    def torch_dtype
      @model.dtype
    end

    def check_model_type(supported_models)
      if !supported_models.is_a?(Array)
        supported_models_names = []
        supported_models.each do |_, model_name|
          # Mapping can now contain tuples of models for the same configuration.
          if model_name.is_a?(Array)
            supported_models_names.concat(model_name)
          else
            supported_models_names << model_name
          end
        end
        supported_models = supported_models_names
      end
      if !supported_models.include?(@model.class.name.split("::").last)
        Transformers.logger.error(
          "The model '#{@model.class.name}' is not supported for #{@task}. Supported models are" +
          " #{supported_models}."
        )
      end
    end

    def get_iterator(
      inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
    )
      if inputs.respond_to?(:size)
        dataset = PipelineDataset.new(inputs, method(:preprocess), preprocess_params)
      else
        if num_workers > 1
          Transformers.logger.warn(
            "For iterable dataset using num_workers>1 is likely to result" +
            " in errors since everything is iterable, setting `num_workers: 1`" +
            " to guarantee correctness."
          )
          num_workers = 1
        end
        dataset = PipelineIterator.new(inputs, method(:preprocess), preprocess_params)
      end

      # TODO hack by collating feature_extractor and image_processor
      feature_extractor = !@feature_extractor.nil? ? @feature_extractor : @image_processor
      collate_fn = batch_size == 1 ? method(:no_collate_fn) : pad_collate_fn(@tokenizer, feature_extractor)
      dataloader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: batch_size, collate_fn: collate_fn) # num_workers: num_workers,
      model_iterator = PipelineIterator.new(dataloader, method(:forward), forward_params, loader_batch_size: batch_size)
      final_iterator = PipelineIterator.new(model_iterator, method(:postprocess), postprocess_params)
      final_iterator
    end

    def call(inputs, *args, num_workers: nil, batch_size: nil, **kwargs)
      if args.any?
        Transformers.logger.warn("Ignoring args : #{args}")
      end

      if num_workers.nil?
        if @num_workers.nil?
          num_workers = 0
        else
          num_workers = @num_workers
        end
      end
      if batch_size.nil?
        if @batch_size.nil?
          batch_size = 1
        else
          batch_size = @batch_size
        end
      end

      preprocess_params, forward_params, postprocess_params = _sanitize_parameters(**kwargs)

      preprocess_params = @preprocess_params.merge(preprocess_params)
      forward_params = @forward_params.merge(forward_params)
      postprocess_params = @postprocess_params.merge(postprocess_params)

      @call_count += 1
      if @call_count > 10 && @framework == "pt" && @device.type == "cuda"
        Transformers.logger.warn(
          "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a" +
          " dataset"
        )
      end

      is_dataset = inputs.is_a?(Torch::Utils::Data::Dataset)
      is_generator = inputs.is_a?(Enumerable)
      is_list = inputs.is_a?(Array)

      _is_iterable = is_dataset || is_generator || is_list

      # TODO make the get_iterator work also for `tf` (and `flax`).
      can_use_iterator = @framework == "pt" && (is_dataset || is_generator || is_list)

      if is_list
        if can_use_iterator
          final_iterator = get_iterator(
            inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
          )
          outputs = final_iterator.to_a
          outputs
        else
          run_multi(inputs, preprocess_params, forward_params, postprocess_params)
        end
      else
        run_single(inputs, preprocess_params, forward_params, postprocess_params)
      end
    end

    private

    def _sanitize_parameters(**kwargs)
      raise NotImplementedError, "_sanitize_parameters not implemented"
    end

    def forward(model_inputs, **forward_params)
      _forward(model_inputs, **forward_params)
    end

    def run_single(inputs, preprocess_params, forward_params, postprocess_params)
      model_inputs = preprocess(inputs, **preprocess_params)
      model_outputs = forward(model_inputs, **forward_params)
      outputs = postprocess(model_outputs, **postprocess_params)
      outputs
    end

    def no_collate_fn(items)
      if items.length != 1
        raise ArgumentError, "This collate_fn is meant to be used with batch_size=1"
      end
      items[0]
    end
  end

  class ChunkPipeline < Pipeline
    def run_single(inputs, preprocess_params, forward_params, postprocess_params)
      all_outputs = []
      preprocess(inputs, **preprocess_params) do |model_inputs|
        model_outputs = forward(model_inputs, **forward_params)
        all_outputs << model_outputs
      end
      outputs = postprocess(all_outputs, **postprocess_params)
      outputs
    end
  end

  class PipelineRegistry
    def initialize(supported_tasks:, task_aliases:)
      @supported_tasks = supported_tasks
      @task_aliases = task_aliases
    end

    def get_supported_tasks
      supported_task = @supported_tasks.keys + @task_aliases.keys
      supported_task.sort
    end

    def check_task(task)
      if @task_aliases[task]
        task = @task_aliases[task]
      end
      if @supported_tasks[task]
        targeted_task = @supported_tasks[task]
        return task, targeted_task, nil
      end

      raise KeyError, "Unknown task #{task}, available tasks are #{get_supported_tasks}"
    end
  end
end
```
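To make the three-stage contract concrete, here is a hypothetical subclass (names and logic invented for illustration): `_sanitize_parameters` splits the caller's kwargs into the three per-stage hashes, and `run_single` chains `preprocess`, `forward` (which delegates to `_forward`), and `postprocess`. Passing `nil` as the model is only safe here because nothing in this path touches `@model`; it also assumes the torch-rb backend is loaded, since `call` references `Torch::Utils::Data::Dataset`.

```ruby
class ShoutPipeline < Transformers::Pipeline
  def _sanitize_parameters(suffix: nil, **_kwargs)
    # Returns [preprocess_params, forward_params, postprocess_params].
    [{}, {}, suffix ? {suffix: suffix} : {}]
  end

  def preprocess(inputs)
    {text: inputs}
  end

  def _forward(model_inputs)
    {text: model_inputs[:text].upcase} # stand-in for a real model call
  end

  def postprocess(model_outputs, suffix: "")
    model_outputs[:text] + suffix
  end
end

# device: -1 skips the accelerator warning; framework must be non-nil
# or initialize raises Todo.
pipe = ShoutPipeline.new(nil, framework: "pt", device: -1)
pipe.("hello")               # => "HELLO"
pipe.("hello", suffix: "!")  # => "HELLO!" (kwargs routed via _sanitize_parameters)
```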
data/lib/transformers/pipelines/feature_extraction.rb
@@ -0,0 +1,47 @@
```ruby
module Transformers
  class FeatureExtractionPipeline < Pipeline
    def _sanitize_parameters(truncation: nil, tokenize_kwargs: nil, return_tensors: nil, **kwargs)
      if tokenize_kwargs.nil?
        tokenize_kwargs = {}
      end

      if !truncation.nil?
        if tokenize_kwargs.include?(:truncation)
          raise ArgumentError,
            "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
        end
        tokenize_kwargs[:truncation] = truncation
      end

      preprocess_params = tokenize_kwargs

      postprocess_params = {}
      if !return_tensors.nil?
        postprocess_params[:return_tensors] = return_tensors
      end

      [preprocess_params, {}, postprocess_params]
    end

    def preprocess(inputs, **tokenize_kwargs)
      model_inputs = @tokenizer.(inputs, return_tensors: @framework, **tokenize_kwargs)
      model_inputs
    end

    def _forward(model_inputs)
      model_outputs = @model.(**model_inputs)
      model_outputs
    end

    def postprocess(model_outputs, return_tensors: false)
      # [0] is the first available tensor, logits or last_hidden_state.
      if return_tensors
        model_outputs[0]
      elsif @framework == "pt"
        model_outputs[0].to_a
      elsif @framework == "tf"
        raise Todo
      end
    end
  end
end
```
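A short usage sketch for this pipeline (the sentence is illustrative; the default checkpoint is the "distilbert/distilbert-base-cased" revision pinned in SUPPORTED_TASKS):

```ruby
extractor = Transformers.pipeline("feature-extraction")

# With the "pt" framework, postprocess returns model_outputs[0].to_a:
# nested Ruby arrays shaped [batch, tokens, hidden_size].
output = extractor.("This is a simple test.")
output[0].length    # token count, including [CLS] and [SEP]
output[0][0].length # hidden size (768 for distilbert-base-cased)
```

Passing `return_tensors: true` skips the array conversion and returns the first output tensor directly.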