transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/pipelines/_init.rb
@@ -0,0 +1,348 @@
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  TASK_ALIASES = {
    "sentiment-analysis" => "text-classification",
    "ner" => "token-classification"
  }

  SUPPORTED_TASKS = {
    "feature-extraction" => {
      "impl" => FeatureExtractionPipeline,
      "pt" => [AutoModel],
      "default" => {
        "model" => {
          "pt" => ["distilbert/distilbert-base-cased", "935ac13"]
        }
      },
      "type" => "multimodal"
    },
    "text-classification" => {
      "impl" => TextClassificationPipeline,
      "pt" => [AutoModelForSequenceClassification],
      "default" => {
        "model" => {
          "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"]
        }
      },
      "type" => "text"
    },
    "token-classification" => {
      "impl" => TokenClassificationPipeline,
      "pt" => [AutoModelForTokenClassification],
      "default" => {
        "model" => {
          "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"]
        }
      },
      "type" => "text"
    },
    "question-answering" => {
      "impl" => QuestionAnsweringPipeline,
      "pt" => [AutoModelForQuestionAnswering],
      "default" => {
        "model" => {
          "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "626af31"]
        }
      },
      "type" => "text"
    },
    "image-classification" => {
      "impl" => ImageClassificationPipeline,
      "pt" => [AutoModelForImageClassification],
      "default" => {
        "model" => {
          "pt" => ["google/vit-base-patch16-224", "5dca96d"]
        }
      },
      "type" => "image"
    },
    "image-feature-extraction" => {
      "impl" => ImageFeatureExtractionPipeline,
      "pt" => [AutoModel],
      "default" => {
        "model" => {
          "pt" => ["google/vit-base-patch16-224", "3f49326"]
        }
      },
      "type" => "image"
    }
  }

  PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)

  class << self
    def pipeline(
      task,
      model: nil,
      config: nil,
      tokenizer: nil,
      feature_extractor: nil,
      image_processor: nil,
      framework: nil,
      revision: nil,
      use_fast: true,
      token: nil,
      device: nil,
      device_map: nil,
      torch_dtype: nil,
      trust_remote_code: nil,
      model_kwargs: nil,
      pipeline_class: nil,
      **kwargs
    )
      model_kwargs ||= {}
      # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
      # this is to keep BC).
      use_auth_token = model_kwargs.delete(:use_auth_token)
      if !use_auth_token.nil?
        raise Todo
      end

      code_revision = kwargs.delete(:code_revision)
      commit_hash = kwargs.delete(:_commit_hash)

      hub_kwargs = {
        revision: revision,
        token: token,
        trust_remote_code: trust_remote_code,
        _commit_hash: commit_hash
      }

      if task.nil? && model.nil?
        raise RuntimeError,
          "Impossible to instantiate a pipeline without either a task or a model " +
          "being specified. " +
          "Please provide a task class or a model"
      end

      if model.nil? && !tokenizer.nil?
        raise RuntimeError,
          "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" +
          " may not be compatible with the default model. Please provide a PreTrainedModel class or a" +
          " path/identifier to a pretrained model when providing tokenizer."
      end
      if model.nil? && !feature_extractor.nil?
        raise RuntimeError,
          "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided" +
          " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" +
          " or a path/identifier to a pretrained model when providing feature_extractor."
      end
      if model.is_a?(Pathname)
        model = model.to_s
      end

      if commit_hash.nil?
        pretrained_model_name_or_path = nil
        if config.is_a?(String)
          pretrained_model_name_or_path = config
        elsif config.nil? && model.is_a?(String)
          pretrained_model_name_or_path = model
        end

        if !config.is_a?(PretrainedConfig) && !pretrained_model_name_or_path.nil?
          # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
          resolved_config_file = Utils::Hub.cached_file(
            pretrained_model_name_or_path,
            CONFIG_NAME,
            _raise_exceptions_for_gated_repo: false,
            _raise_exceptions_for_missing_entries: false,
            _raise_exceptions_for_connection_errors: false,
            cache_dir: model_kwargs[:cache_dir],
            **hub_kwargs
          )
          hub_kwargs[:_commit_hash] = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
        else
          hub_kwargs[:_commit_hash] = nil # getattr(config, "_commit_hash", None)
        end
      end

      # Config is the primordial information item.
      # Instantiate config if needed
      if config.is_a?(String)
        raise Todo
      elsif config.nil? && model.is_a?(String)
        config = AutoConfig.from_pretrained(
          model, _from_pipeline: task, code_revision: code_revision, **hub_kwargs, **model_kwargs
        )
        hub_kwargs[:_commit_hash] = config._commit_hash
      end

      custom_tasks = {}
      if !config.nil? && (config.instance_variable_get(:@custom_pipelines) || {}).length > 0
        raise Todo
      end

      if task.nil? && !model.nil?
        raise Todo
      end

      # Retrieve the task
      if custom_tasks.include?(task)
        raise Todo
      else
        _normalized_task, targeted_task, task_options = check_task(task)
        if pipeline_class.nil?
          pipeline_class = targeted_task["impl"]
        end
      end

      # Use default model/config/tokenizer for the task if no model is provided
      if model.nil?
        # At that point framework might still be undetermined
        model, default_revision = Pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
        revision = !revision.nil? ? revision : default_revision
        Transformers.logger.warn(
          "No model was supplied, defaulted to #{model} and revision" +
          " #{revision} (#{Utils::Hub::HUGGINGFACE_CO_RESOLVE_ENDPOINT}/#{model}).\n" +
          "Using a pipeline without specifying a model name and revision in production is not recommended."
        )
        if config.nil? && model.is_a?(String)
          config = AutoConfig.from_pretrained(model, _from_pipeline: task, **hub_kwargs, **model_kwargs)
          hub_kwargs[:_commit_hash] = config._commit_hash
        end
      end

      if !device_map.nil?
        raise Todo
      end
      if !torch_dtype.nil?
        raise Todo
      end

      model_name = model.is_a?(String) ? model : nil

      # Load the correct model if possible
      # Infer the framework from the model if not already defined
      if model.is_a?(String) || framework.nil?
        model_classes = {"tf" => targeted_task["tf"], "pt" => targeted_task["pt"]}
        framework, model =
          Pipelines.infer_framework_load_model(
            model,
            config,
            model_classes: model_classes,
            framework: framework,
            task: task,
            **hub_kwargs,
            **model_kwargs
          )
      end

      model_config = model.config
      hub_kwargs[:_commit_hash] = model.config._commit_hash
      model_config_type = model_config.class.name.split("::").last
      load_tokenizer = TOKENIZER_MAPPING.include?(model_config_type) || !model_config.tokenizer_class.nil?
      load_feature_extractor = FEATURE_EXTRACTOR_MAPPING.include?(model_config_type) || !feature_extractor.nil?
      load_image_processor = IMAGE_PROCESSOR_MAPPING.include?(model_config_type) || !image_processor.nil?

      if load_tokenizer
        # Try to infer tokenizer from model or config name (if provided as str)
        if tokenizer.nil?
          if model_name.is_a?(String)
            tokenizer = model_name
          elsif config.is_a?(String)
            tokenizer = config
          else
            # Impossible to guess what is the right tokenizer here
            raise "Impossible to guess which tokenizer to use. Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
          end
        end

        # Instantiate tokenizer if needed
        if tokenizer.is_a?(String) || tokenizer.is_a?(Array)
          if tokenizer.is_a?(Array)
            # For array we have [tokenizer name, {kwargs}]
            use_fast = tokenizer[1].delete(:use_fast) { use_fast }
            tokenizer_identifier = tokenizer[0]
            tokenizer_kwargs = tokenizer[1]
          else
            tokenizer_identifier = tokenizer
            tokenizer_kwargs = model_kwargs.dup
            tokenizer_kwargs.delete(:torch_dtype)
          end

          tokenizer =
            AutoTokenizer.from_pretrained(
              tokenizer_identifier, use_fast: use_fast, _from_pipeline: task, **hub_kwargs, **tokenizer_kwargs
            )
        end
      end

      if load_image_processor
        # Try to infer image processor from model or config name (if provided as str)
        if image_processor.nil?
          if model_name.is_a?(String)
            image_processor = model_name
          elsif config.is_a?(String)
            image_processor = config
          # Backward compatibility, as `feature_extractor` used to be the name
          # for `ImageProcessor`.
          elsif !feature_extractor.nil? && feature_extractor.is_a?(BaseImageProcessor)
            image_processor = feature_extractor
          else
            # Impossible to guess what is the right image_processor here
            raise RuntimeError,
              "Impossible to guess which image processor to use. " +
              "Please provide a PreTrainedImageProcessor class or a path/identifier " +
              "to a pretrained image processor."
          end
        end

        # Instantiate image_processor if needed
        if image_processor.is_a?(String) || image_processor.is_a?(Array)
          image_processor = AutoImageProcessor.from_pretrained(
            image_processor, _from_pipeline: task, **hub_kwargs, **model_kwargs
          )
        end
      end

      if load_feature_extractor
        raise Todo
      end

      if task == "translation" && model.config.task_specific_params
        raise Todo
      end

      if !tokenizer.nil?
        kwargs[:tokenizer] = tokenizer
      end

      if !feature_extractor.nil?
        kwargs[:feature_extractor] = feature_extractor
      end

      if !torch_dtype.nil?
        kwargs[:torch_dtype] = torch_dtype
      end

      if !image_processor.nil?
        kwargs[:image_processor] = image_processor
      end

      if !device.nil?
        kwargs[:device] = device
      end

      pipeline_class.new(model, framework: framework, task: task, **kwargs)
    end

    private

    def check_task(task)
      PIPELINE_REGISTRY.check_task(task)
    end
  end
end
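For orientation, the factory above can be exercised as follows. This is a minimal usage sketch, not part of the diff: it assumes the gem is loaded via require "transformers-rb" and that the pinned default checkpoints in SUPPORTED_TASKS can be downloaded from the Hugging Face Hub.

require "transformers-rb"

# "sentiment-analysis" is rewritten to "text-classification" via TASK_ALIASES,
# so this pulls the default SST-2 DistilBERT checkpoint pinned at revision af0f99b
# (and logs the "No model was supplied" warning seen in the code above).
classifier = Transformers.pipeline("sentiment-analysis")
classifier.("We are very happy to show you the Transformers library.")

# Passing a model skips the default lookup; the task is still validated
# against PIPELINE_REGISTRY by check_task.
extractor = Transformers.pipeline(
  "feature-extraction",
  model: "distilbert/distilbert-base-cased"
)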
data/lib/transformers/pipelines/base.rb
@@ -0,0 +1,301 @@
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Pipelines
    def self.get_default_model_and_revision(targeted_task, framework, task_options)
      defaults = targeted_task["default"]

      if defaults.key?("model")
        default_models = targeted_task["default"]["model"]
      end

      if framework.nil?
        framework = "pt"
      end

      default_models[framework]
    end

    def self.infer_framework_load_model(
      model,
      config,
      model_classes: nil,
      task: nil,
      framework: nil,
      **model_kwargs
    )
      if model.is_a?(String)
        model_kwargs[:_from_pipeline] = task
        class_tuple = []
        look_pt = true

        if model_classes
          if look_pt
            class_tuple = class_tuple + model_classes.fetch("pt", AutoModel)
          end
        end
        if config.architectures
          classes = []
          config.architectures.each do |architecture|
            if look_pt
              _class = Transformers.const_get(architecture)
              if !_class.nil?
                classes << _class
              end
            end
          end
          class_tuple = class_tuple + classes
        end

        if class_tuple.length == 0
          raise ArgumentError, "Pipeline cannot infer suitable model classes from #{model}"
        end

        class_tuple.each do |model_class|
          raise Error, "Invalid auto model class: #{model_class}" unless model_class < BaseAutoModelClass
          kwargs = model_kwargs.dup

          begin
            model = model_class.from_pretrained(model, **kwargs)
            if model.respond_to?(:eval)
              model = model.eval
            end
            break
          rescue
            # TODO
            raise
          end
        end
      end

      if framework.nil?
        framework = Utils.infer_framework(model.class)
      end
      [framework, model]
    end
  end

  class ArgumentHandler
  end

  class Pipeline
    def initialize(
      model,
      tokenizer: nil,
      feature_extractor: nil,
      image_processor: nil,
      modelcard: nil,
      framework: nil,
      task: "",
      device: nil,
      **kwargs
    )
      if framework.nil?
        raise Todo
      end

      @task = task
      @model = model
      @tokenizer = tokenizer
      @feature_extractor = feature_extractor
      @image_processor = image_processor
      @modelcard = modelcard
      @framework = framework

      if device.nil?
        if Torch::CUDA.available? || Torch::Backends::MPS.available?
          Transformers.logger.warn(
            "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument" +
            " is passed to the `Pipeline` object. Model will be on CPU."
          )
        end
      end

      @call_count = 0
      @batch_size = kwargs.delete(:batch_size)
      @num_workers = kwargs.delete(:num_workers)
      @preprocess_params, @forward_params, @postprocess_params = _sanitize_parameters(**kwargs)
    end

    def torch_dtype
      @model.dtype
    end

    def check_model_type(supported_models)
      if !supported_models.is_a?(Array)
        supported_models_names = []
        supported_models.each do |_, model_name|
          # Mapping can now contain tuples of models for the same configuration.
          if model_name.is_a?(Array)
            supported_models_names.concat(model_name)
          else
            supported_models_names << model_name
          end
        end
        supported_models = supported_models_names
      end
      if !supported_models.include?(@model.class.name.split("::").last)
        Transformers.logger.error(
          "The model '#{@model.class.name}' is not supported for #{@task}. Supported models are" +
          " #{supported_models}."
        )
      end
    end

    def get_iterator(
      inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
    )
      if inputs.respond_to?(:size)
        dataset = PipelineDataset.new(inputs, method(:preprocess), preprocess_params)
      else
        if num_workers > 1
          Transformers.logger.warn(
            "For iterable dataset using num_workers>1 is likely to result" +
            " in errors since everything is iterable, setting `num_workers: 1`" +
            " to guarantee correctness."
          )
          num_workers = 1
        end
        dataset = PipelineIterator.new(inputs, method(:preprocess), preprocess_params)
      end

      # TODO hack by collating feature_extractor and image_processor
      feature_extractor = !@feature_extractor.nil? ? @feature_extractor : @image_processor
      collate_fn = batch_size == 1 ? method(:no_collate_fn) : pad_collate_fn(@tokenizer, feature_extractor)
      dataloader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: batch_size, collate_fn: collate_fn) # num_workers: num_workers,
      model_iterator = PipelineIterator.new(dataloader, method(:forward), forward_params, loader_batch_size: batch_size)
      final_iterator = PipelineIterator.new(model_iterator, method(:postprocess), postprocess_params)
      final_iterator
    end

    def call(inputs, *args, num_workers: nil, batch_size: nil, **kwargs)
      if args.any?
        Transformers.logger.warn("Ignoring args : #{args}")
      end

      if num_workers.nil?
        if @num_workers.nil?
          num_workers = 0
        else
          num_workers = @num_workers
        end
      end
      if batch_size.nil?
        if @batch_size.nil?
          batch_size = 1
        else
          batch_size = @batch_size
        end
      end

      preprocess_params, forward_params, postprocess_params = _sanitize_parameters(**kwargs)

      preprocess_params = @preprocess_params.merge(preprocess_params)
      forward_params = @forward_params.merge(forward_params)
      postprocess_params = @postprocess_params.merge(postprocess_params)

      @call_count += 1
      if @call_count > 10 && @framework == "pt" && @device.type == "cuda"
        Transformers.logger.warn(
          "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a" +
          " dataset"
        )
      end

      is_dataset = inputs.is_a?(Torch::Utils::Data::Dataset)
      is_generator = inputs.is_a?(Enumerable)
      is_list = inputs.is_a?(Array)

      _is_iterable = is_dataset || is_generator || is_list

      # TODO make the get_iterator work also for `tf` (and `flax`).
      can_use_iterator = @framework == "pt" && (is_dataset || is_generator || is_list)

      if is_list
        if can_use_iterator
          final_iterator = get_iterator(
            inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
          )
          outputs = final_iterator.to_a
          outputs
        else
          run_multi(inputs, preprocess_params, forward_params, postprocess_params)
        end
      else
        run_single(inputs, preprocess_params, forward_params, postprocess_params)
      end
    end

    private

    def _sanitize_parameters(**kwargs)
      raise NotImplementedError, "_sanitize_parameters not implemented"
    end

    def forward(model_inputs, **forward_params)
      _forward(model_inputs, **forward_params)
    end

    def run_single(inputs, preprocess_params, forward_params, postprocess_params)
      model_inputs = preprocess(inputs, **preprocess_params)
      model_outputs = forward(model_inputs, **forward_params)
      outputs = postprocess(model_outputs, **postprocess_params)
      outputs
    end

    def no_collate_fn(items)
      if items.length != 1
        raise ArgumentError, "This collate_fn is meant to be used with batch_size=1"
      end
      items[0]
    end
  end

  class ChunkPipeline < Pipeline
    def run_single(inputs, preprocess_params, forward_params, postprocess_params)
      all_outputs = []
      preprocess(inputs, **preprocess_params) do |model_inputs|
        model_outputs = forward(model_inputs, **forward_params)
        all_outputs << model_outputs
      end
      outputs = postprocess(all_outputs, **postprocess_params)
      outputs
    end
  end

  class PipelineRegistry
    def initialize(supported_tasks:, task_aliases:)
      @supported_tasks = supported_tasks
      @task_aliases = task_aliases
    end

    def get_supported_tasks
      supported_task = @supported_tasks.keys + @task_aliases.keys
      supported_task.sort
    end

    def check_task(task)
      if @task_aliases[task]
        task = @task_aliases[task]
      end
      if @supported_tasks[task]
        targeted_task = @supported_tasks[task]
        return task, targeted_task, nil
      end

      raise KeyError, "Unknown task #{task}, available tasks are #{get_supported_tasks}"
    end
  end
end
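The base class above defines a three-hook contract for subclasses: _sanitize_parameters splits keyword arguments into preprocess/forward/postprocess hashes (once at construction, again per call), and run_single chains preprocess, forward (which delegates to _forward), and postprocess. The following toy subclass is purely illustrative and not part of the gem; it exists only to show which hooks a subclass provides and how the single-input path flows.

module Transformers
  # Hypothetical example, not in transformers-rb: a model-free "pipeline"
  # that upcases strings.
  class UpcaseExamplePipeline < Pipeline
    def _sanitize_parameters(suffix: nil, **kwargs)
      # Returns [preprocess_params, forward_params, postprocess_params].
      [{}, {}, suffix.nil? ? {} : {suffix: suffix}]
    end

    def preprocess(inputs)
      {text: inputs.to_s}
    end

    def _forward(model_inputs)
      {text: model_inputs[:text].upcase}
    end

    def postprocess(model_outputs, suffix: "")
      model_outputs[:text] + suffix
    end
  end
end

# framework must be supplied, since initialize raises Todo when it is nil.
pipe = Transformers::UpcaseExamplePipeline.new(nil, framework: "pt", task: "upcase")
pipe.("hello", suffix: "!") # => "HELLO!" via the run_single path

Array input would instead be routed through get_iterator (batching via Torch::Utils::Data::DataLoader), which assumes a real model and tokenizer.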
data/lib/transformers/pipelines/feature_extraction.rb
@@ -0,0 +1,47 @@
module Transformers
  class FeatureExtractionPipeline < Pipeline
    def _sanitize_parameters(truncation: nil, tokenize_kwargs: nil, return_tensors: nil, **kwargs)
      if tokenize_kwargs.nil?
        tokenize_kwargs = {}
      end

      if !truncation.nil?
        if tokenize_kwargs.include?(:truncation)
          raise ArgumentError,
            "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
        end
        tokenize_kwargs[:truncation] = truncation
      end

      preprocess_params = tokenize_kwargs

      postprocess_params = {}
      if !return_tensors.nil?
        postprocess_params[:return_tensors] = return_tensors
      end

      [preprocess_params, {}, postprocess_params]
    end

    def preprocess(inputs, **tokenize_kwargs)
      model_inputs = @tokenizer.(inputs, return_tensors: @framework, **tokenize_kwargs)
      model_inputs
    end

    def _forward(model_inputs)
      model_outputs = @model.(**model_inputs)
      model_outputs
    end

    def postprocess(model_outputs, return_tensors: false)
      # [0] is the first available tensor, logits or last_hidden_state.
      if return_tensors
        model_outputs[0]
      elsif @framework == "pt"
        model_outputs[0].to_a
      elsif @framework == "tf"
        raise Todo
      end
    end
  end
end
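A usage sketch for the class above, assuming the default checkpoint pinned in SUPPORTED_TASKS: _sanitize_parameters routes truncation into the tokenizer call, while return_tensors only affects postprocess.

extractor = Transformers.pipeline("feature-extraction")

# postprocess calls .to_a on the first output tensor, so this returns
# nested Ruby arrays shaped [batch][token][hidden_size].
features = extractor.("Hello world", truncation: true)

# Keep the raw Torch tensor (logits or last_hidden_state) instead.
tensor = extractor.("Hello world", return_tensors: true)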