transformers-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (65) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
@@ -0,0 +1,348 @@
1
+ # Copyright 2018 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
# Alternate task names, each mapped to the canonical key used in SUPPORTED_TASKS.
TASK_ALIASES = {
  "sentiment-analysis" => "text-classification",
  "ner" => "token-classification"
}

# Builds one SUPPORTED_TASKS entry. Every task shares the same shape:
# the pipeline implementation, the auto model classes under "pt", the default
# [model id, revision] pair under "default" => "model" => "pt", and the input type.
task_entry = lambda do |impl:, pt:, default:, type:|
  {
    "impl" => impl,
    "pt" => pt,
    "default" => {
      "model" => {
        "pt" => default
      }
    },
    "type" => type
  }
end

# Task name => description of how to build a pipeline for that task.
SUPPORTED_TASKS = {
  "feature-extraction" => task_entry.(
    impl: FeatureExtractionPipeline,
    pt: [AutoModel],
    default: ["distilbert/distilbert-base-cased", "935ac13"],
    type: "multimodal"
  ),
  "text-classification" => task_entry.(
    impl: TextClassificationPipeline,
    pt: [AutoModelForSequenceClassification],
    default: ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"],
    type: "text"
  ),
  "token-classification" => task_entry.(
    impl: TokenClassificationPipeline,
    pt: [AutoModelForTokenClassification],
    default: ["dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"],
    type: "text"
  ),
  "question-answering" => task_entry.(
    impl: QuestionAnsweringPipeline,
    pt: [AutoModelForQuestionAnswering],
    default: ["distilbert/distilbert-base-cased-distilled-squad", "626af31"],
    type: "text"
  ),
  "image-classification" => task_entry.(
    impl: ImageClassificationPipeline,
    pt: [AutoModelForImageClassification],
    default: ["google/vit-base-patch16-224", "5dca96d"],
    type: "image"
  ),
  "image-feature-extraction" => task_entry.(
    impl: ImageFeatureExtractionPipeline,
    pt: [AutoModel],
    default: ["google/vit-base-patch16-224", "3f49326"],
    type: "image"
  )
}

# Registry used by `check_task` to resolve task names and aliases.
PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
85
+
86
class << self
  # Builds a pipeline object for +task+.
  #
  # The method resolves, in this order: the configuration, the task
  # definition (via the registry), the model (falling back to the task's
  # default model when none is given), and then the tokenizer and/or image
  # processor that the model's configuration type calls for. The resolved
  # pieces are handed to the pipeline class as keyword arguments.
  #
  # Several code paths are not implemented yet and raise +Todo+: a
  # +use_auth_token+ entry in +model_kwargs+, a String +config+, configs that
  # declare custom pipelines, a model given without a task, +device_map+,
  # +torch_dtype+, and any model that needs a feature extractor.
  #
  # @param task [String] task name or alias known to the registry
  # @param model [String, Pathname, Object, nil] model identifier/path or an already loaded model
  # @param config [Object, nil] configuration object (a String raises +Todo+)
  # @param tokenizer [String, Array, Object, nil] tokenizer identifier,
  #   +[identifier, options_hash]+ pair, or an already built tokenizer
  # @param feature_extractor [Object, nil] forwarded to the pipeline when given
  # @param image_processor [String, Array, Object, nil] identifier or already built processor
  # @param framework [String, nil] inferred from the model when nil
  # @param revision [String, nil] revision forwarded to the hub helpers
  # @param use_fast [Boolean] passed to +AutoTokenizer.from_pretrained+
  # @param token [String, nil] forwarded to the hub helpers
  # @param device [Object, nil] forwarded to the pipeline when given
  # @param model_kwargs [Hash, nil] extra keyword arguments for the +from_pretrained+ calls
  # @param pipeline_class [Class, nil] overrides the task's default implementation
  # @return [Object] an instance of the selected pipeline class
  # @raise [RuntimeError] when neither task nor model is given, when a
  #   tokenizer/feature extractor is given without a model, or when the
  #   tokenizer/image processor cannot be inferred
  def pipeline(
    task,
    model: nil,
    config: nil,
    tokenizer: nil,
    feature_extractor: nil,
    image_processor: nil,
    framework: nil,
    revision: nil,
    use_fast: true,
    token: nil,
    device: nil,
    device_map: nil,
    torch_dtype: nil,
    trust_remote_code: nil,
    model_kwargs: nil,
    pipeline_class: nil,
    **kwargs
  )
    model_kwargs ||= {}
    # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
    # this is to keep BC).
    use_auth_token = model_kwargs.delete(:use_auth_token)
    if !use_auth_token.nil?
      raise Todo
    end

    code_revision = kwargs.delete(:code_revision)
    commit_hash = kwargs.delete(:_commit_hash)

    hub_kwargs = {
      revision: revision,
      token: token,
      trust_remote_code: trust_remote_code,
      _commit_hash: commit_hash
    }

    if task.nil? && model.nil?
      raise RuntimeError,
        "Impossible to instantiate a pipeline without either a task or a model " +
        "being specified. " +
        "Please provide a task class or a model"
    end

    if model.nil? && !tokenizer.nil?
      raise RuntimeError,
        "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" +
        " may not be compatible with the default model. Please provide a PreTrainedModel class or a" +
        " path/identifier to a pretrained model when providing tokenizer."
    end
    if model.nil? && !feature_extractor.nil?
      raise RuntimeError,
        "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided" +
        " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" +
        " or a path/identifier to a pretrained model when providing feature_extractor."
    end
    if model.is_a?(Pathname)
      model = model.to_s
    end

    if commit_hash.nil?
      pretrained_model_name_or_path = nil
      if config.is_a?(String)
        pretrained_model_name_or_path = config
      elsif config.nil? && model.is_a?(String)
        pretrained_model_name_or_path = model
      end

      if !config.is_a?(PretrainedConfig) && !pretrained_model_name_or_path.nil?
        # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
        resolved_config_file = Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          CONFIG_NAME,
          _raise_exceptions_for_gated_repo: false,
          _raise_exceptions_for_missing_entries: false,
          _raise_exceptions_for_connection_errors: false,
          cache_dir: model_kwargs[:cache_dir],
          **hub_kwargs
        )
        hub_kwargs[:_commit_hash] = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
      else
        hub_kwargs[:_commit_hash] = nil # getattr(config, "_commit_hash", None)
      end
    end

    # Config is the primordial information item.
    # Instantiate config if needed
    if config.is_a?(String)
      raise Todo
    elsif config.nil? && model.is_a?(String)
      config = AutoConfig.from_pretrained(
        model, _from_pipeline: task, code_revision: code_revision, **hub_kwargs, **model_kwargs
      )
      hub_kwargs[:_commit_hash] = config._commit_hash
    end

    custom_tasks = {}
    if !config.nil? && (config.instance_variable_get(:@custom_pipelines) || {}).length > 0
      raise Todo
    end

    if task.nil? && !model.nil?
      raise Todo
    end

    # Retrieve the task
    if custom_tasks.include?(task)
      raise Todo
    else
      _normalized_task, targeted_task, task_options = check_task(task)
      if pipeline_class.nil?
        pipeline_class = targeted_task["impl"]
      end
    end

    # Use default model/config/tokenizer for the task if no model is provided
    if model.nil?
      # At that point framework might still be undetermined
      model, default_revision = Pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
      # NOTE(review): the revision resolved here is only used in the log
      # message below; hub_kwargs[:revision] keeps the caller-supplied value,
      # so the default revision is not forwarded to the hub helpers. Confirm
      # whether that is intended before changing it.
      revision = !revision.nil? ? revision : default_revision
      Transformers.logger.warn(
        "No model was supplied, defaulted to #{model} and revision" +
        " #{revision} (#{Utils::Hub::HUGGINGFACE_CO_RESOLVE_ENDPOINT}/#{model}).\n" +
        "Using a pipeline without specifying a model name and revision in production is not recommended."
      )
      if config.nil? && model.is_a?(String)
        config = AutoConfig.from_pretrained(model, _from_pipeline: task, **hub_kwargs, **model_kwargs)
        hub_kwargs[:_commit_hash] = config._commit_hash
      end
    end

    if !device_map.nil?
      raise Todo
    end
    if !torch_dtype.nil?
      raise Todo
    end

    model_name = model.is_a?(String) ? model : nil

    # Load the correct model if possible
    # Infer the framework from the model if not already defined
    if model.is_a?(String) || framework.nil?
      model_classes = {"tf" => targeted_task["tf"], "pt" => targeted_task["pt"]}
      framework, model =
        Pipelines.infer_framework_load_model(
          model,
          config,
          model_classes: model_classes,
          framework: framework,
          task: task,
          **hub_kwargs,
          **model_kwargs
        )
    end

    model_config = model.config
    hub_kwargs[:_commit_hash] = model.config._commit_hash
    model_config_type = model_config.class.name.split("::").last
    load_tokenizer = TOKENIZER_MAPPING.include?(model_config_type) || !model_config.tokenizer_class.nil?
    load_feature_extractor = FEATURE_EXTRACTOR_MAPPING.include?(model_config_type) || !feature_extractor.nil?
    load_image_processor = IMAGE_PROCESSOR_MAPPING.include?(model_config_type) || !image_processor.nil?

    if load_tokenizer
      # Try to infer tokenizer from model or config name (if provided as str)
      if tokenizer.nil?
        if model_name.is_a?(String)
          tokenizer = model_name
        elsif config.is_a?(String)
          tokenizer = config
        else
          # Impossible to guess what is the right tokenizer here
          raise RuntimeError,
            "Impossible to guess which tokenizer to use. Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
        end
      end

      # Instantiate tokenizer if needed
      if tokenizer.is_a?(String) || tokenizer.is_a?(Array)
        if tokenizer.is_a?(Array)
          # For array we have [tokenizer name, {kwargs}]
          tokenizer_identifier = tokenizer[0]
          # Work on a copy: deleting :use_fast from the caller's own hash
          # would make a second call with the same array lose that option.
          tokenizer_kwargs = (tokenizer[1] || {}).dup
          use_fast = tokenizer_kwargs.delete(:use_fast) { use_fast }
        else
          tokenizer_identifier = tokenizer
          tokenizer_kwargs = model_kwargs.dup
          tokenizer_kwargs.delete(:torch_dtype)
        end

        tokenizer =
          AutoTokenizer.from_pretrained(
            tokenizer_identifier, use_fast: use_fast, _from_pipeline: task, **hub_kwargs, **tokenizer_kwargs
          )
      end
    end

    if load_image_processor
      # Try to infer image processor from model or config name (if provided as str)
      if image_processor.nil?
        if model_name.is_a?(String)
          image_processor = model_name
        elsif config.is_a?(String)
          image_processor = config
        # Backward compatibility, as `feature_extractor` used to be the name
        # for `ImageProcessor`.
        elsif !feature_extractor.nil? && feature_extractor.is_a?(BaseImageProcessor)
          image_processor = feature_extractor
        else
          # Impossible to guess what is the right image_processor here
          raise RuntimeError,
            "Impossible to guess which image processor to use. " +
            "Please provide a PreTrainedImageProcessor class or a path/identifier " +
            "to a pretrained image processor."
        end
      end

      # Instantiate image_processor if needed
      if image_processor.is_a?(String) || image_processor.is_a?(Array)
        image_processor = AutoImageProcessor.from_pretrained(
          image_processor, _from_pipeline: task, **hub_kwargs, **model_kwargs
        )
      end
    end

    if load_feature_extractor
      raise Todo
    end

    if task == "translation" && model.config.task_specific_params
      raise Todo
    end

    # Only forward the components that were actually resolved.
    if !tokenizer.nil?
      kwargs[:tokenizer] = tokenizer
    end

    if !feature_extractor.nil?
      kwargs[:feature_extractor] = feature_extractor
    end

    if !torch_dtype.nil?
      kwargs[:torch_dtype] = torch_dtype
    end

    if !image_processor.nil?
      kwargs[:image_processor] = image_processor
    end

    if !device.nil?
      kwargs[:device] = device
    end

    pipeline_class.new(model, framework: framework, task: task, **kwargs)
  end

  private

  # Delegates to the registry.
  #
  # @param task [String] task name or alias
  # @return [Array] whatever PIPELINE_REGISTRY.check_task returns
  def check_task(task)
    PIPELINE_REGISTRY.check_task(task)
  end
end
348
+ end
@@ -0,0 +1,301 @@
1
+ # Copyright 2018 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
# Helpers used while assembling a pipeline: default-model lookup and model loading.
module Pipelines
  # Looks up the default model entry of a task for the given framework.
  #
  # @param targeted_task [Hash] task definition; its "default" => "model" hash
  #   is indexed by framework name
  # @param framework [String, nil] framework key; "pt" is used when nil
  # @param task_options [Object, nil] currently unused by this method
  # @return [Object] the value stored under the framework key (the task tables
  #   in this project store a [model id, revision] pair there)
  # @raise [ArgumentError] when the task definition has no "model" defaults
  def self.get_default_model_and_revision(targeted_task, framework, task_options)
    defaults = targeted_task["default"]

    # Fail with a clear message instead of a NoMethodError on nil further down.
    unless defaults.key?("model")
      raise ArgumentError, "The task defaults can't be correctly selected: no default model is defined for this task"
    end

    framework = "pt" if framework.nil?
    defaults["model"][framework]
  end

  # Loads +model+ when it is given as a String and determines the framework.
  #
  # Candidate classes are the task's "pt" classes followed by every class
  # named in +config.architectures+ that is defined under +Transformers+.
  # Architecture names that do not resolve to a constant are skipped.
  #
  # @param model [String, Object] model identifier, or an already loaded model
  #   (returned untouched)
  # @param config [Object] configuration; +architectures+ is read when +model+ is a String
  # @param model_classes [Hash, nil] framework name => model class(es)
  # @param task [Object, nil] forwarded to +from_pretrained+ as +_from_pipeline+
  # @param framework [String, nil] inferred via +Utils.infer_framework+ when nil
  # @param model_kwargs [Hash] extra keyword arguments for +from_pretrained+
  # @return [Array] +[framework, model]+
  # @raise [ArgumentError] when no candidate class can be determined
  # @raise [Error] when the selected class is not a BaseAutoModelClass subclass
  def self.infer_framework_load_model(
    model,
    config,
    model_classes: nil,
    task: nil,
    framework: nil,
    **model_kwargs
  )
    if model.is_a?(String)
      model_kwargs[:_from_pipeline] = task

      candidates = []
      if model_classes
        # Array() accepts a single class, a list of classes, or nil; the
        # fallback is only evaluated when the "pt" key is missing.
        candidates += Array(model_classes.fetch("pt") { AutoModel })
      end
      if config.architectures
        config.architectures.each do |architecture|
          resolved = resolve_architecture(architecture)
          candidates << resolved unless resolved.nil?
        end
      end

      if candidates.length == 0
        raise ArgumentError, "Pipeline cannot infer suitable model classes from #{model}"
      end

      # TODO try the remaining candidates when loading fails; for now only the
      # first candidate is attempted and any error it raises propagates.
      model_class = candidates.first
      raise Error, "Invalid auto model class: #{model_class}" unless model_class < BaseAutoModelClass

      model = model_class.from_pretrained(model, **model_kwargs)
      if model.respond_to?(:eval)
        model = model.eval
      end
    end

    if framework.nil?
      framework = Utils.infer_framework(model.class)
    end
    [framework, model]
  end

  # Returns the constant called +name+ under Transformers, or nil when there
  # is no such constant (or +name+ is not a valid constant name).
  def self.resolve_architecture(name)
    Transformers.const_get(name) if Transformers.const_defined?(name)
  rescue NameError
    nil
  end
  private_class_method :resolve_architecture
end
89
+
90
# Placeholder class: it declares no methods of its own.
class ArgumentHandler; end
92
+
93
# Base class for task pipelines.
#
# A call runs three steps: preprocess -> forward -> postprocess. This class
# calls +preprocess+, +_forward+, +postprocess+ and +_sanitize_parameters+;
# only +_sanitize_parameters+ has a (raising) default here, so subclasses are
# expected to supply all four.
class Pipeline
  # @param model [Object] loaded model
  # @param tokenizer [Object, nil]
  # @param feature_extractor [Object, nil]
  # @param image_processor [Object, nil]
  # @param modelcard [Object, nil]
  # @param framework [String] required; nil raises +Todo+
  # @param task [String]
  # @param device [Object, nil] stored as given; this class does not move the model
  # @param kwargs [Hash] +:batch_size+ and +:num_workers+ are extracted, the
  #   rest goes to +_sanitize_parameters+
  def initialize(
    model,
    tokenizer: nil,
    feature_extractor: nil,
    image_processor: nil,
    modelcard: nil,
    framework: nil,
    task: "",
    device: nil,
    **kwargs
  )
    raise Todo if framework.nil?

    @task = task
    @model = model
    @tokenizer = tokenizer
    @feature_extractor = feature_extractor
    @image_processor = image_processor
    @modelcard = modelcard
    @framework = framework
    # Kept so `call` can inspect it; it was previously read without ever being set.
    @device = device

    if device.nil? && (Torch::CUDA.available? || Torch::Backends::MPS.available?)
      Transformers.logger.warn(
        "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument" +
        " is passed to the `Pipeline` object. Model will be on CPU."
      )
    end

    @call_count = 0
    @batch_size = kwargs.delete(:batch_size)
    @num_workers = kwargs.delete(:num_workers)
    @preprocess_params, @forward_params, @postprocess_params = _sanitize_parameters(**kwargs)
  end

  # @return [Object] the dtype reported by the model
  def torch_dtype
    @model.dtype
  end

  # Logs an error when the model's class name (without namespace) is not in
  # +supported_models+. Nothing is raised.
  #
  # @param supported_models [Array, #each] list of names, or a mapping whose
  #   values are a name or an array of names
  def check_model_type(supported_models)
    if !supported_models.is_a?(Array)
      names = []
      supported_models.each do |_, model_name|
        # Mapping can now contain tuples of models for the same configuration.
        if model_name.is_a?(Array)
          names.concat(model_name)
        else
          names << model_name
        end
      end
      supported_models = names
    end

    if !supported_models.include?(@model.class.name.split("::").last)
      Transformers.logger.error(
        "The model '#{@model.class.name}' is not supported for #{@task}. Supported models are" +
        " #{supported_models}."
      )
    end
  end

  # Chains dataset -> data loader -> forward -> postprocess into one iterator.
  #
  # NOTE(review): +pad_collate_fn+ (used when batch_size != 1) is not defined
  # in this class; presumably it lives elsewhere in the project.
  #
  # @return [PipelineIterator] iterator over the postprocessed outputs
  def get_iterator(
    inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
  )
    if inputs.respond_to?(:size)
      dataset = PipelineDataset.new(inputs, method(:preprocess), preprocess_params)
    else
      if num_workers > 1
        Transformers.logger.warn(
          "For iterable dataset using num_workers>1 is likely to result" +
          " in errors since everything is iterable, setting `num_workers: 1`" +
          " to guarantee correctness."
        )
        num_workers = 1
      end
      dataset = PipelineIterator.new(inputs, method(:preprocess), preprocess_params)
    end

    # TODO hack by collating feature_extractor and image_processor
    feature_extractor = !@feature_extractor.nil? ? @feature_extractor : @image_processor
    collate_fn = batch_size == 1 ? method(:no_collate_fn) : pad_collate_fn(@tokenizer, feature_extractor)
    dataloader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: batch_size, collate_fn: collate_fn) # num_workers: num_workers,
    model_iterator = PipelineIterator.new(dataloader, method(:forward), forward_params, loader_batch_size: batch_size)
    PipelineIterator.new(model_iterator, method(:postprocess), postprocess_params)
  end

  # Runs the pipeline on one input, or on each element of an Array.
  #
  # @param inputs [Object, Array] a single input or an Array of inputs
  # @param args [Array] extra positional arguments are ignored (with a warning)
  # @param num_workers [Integer, nil] falls back to the constructor value, then 0
  # @param batch_size [Integer, nil] falls back to the constructor value, then 1
  # @param kwargs [Hash] per-call parameters, merged over the constructor's
  # @return [Object, Array] one result, or an Array of results for Array input
  def call(inputs, *args, num_workers: nil, batch_size: nil, **kwargs)
    if args.any?
      Transformers.logger.warn("Ignoring args : #{args}")
    end

    if num_workers.nil?
      num_workers = @num_workers.nil? ? 0 : @num_workers
    end
    if batch_size.nil?
      batch_size = @batch_size.nil? ? 1 : @batch_size
    end

    preprocess_params, forward_params, postprocess_params = _sanitize_parameters(**kwargs)

    # Per-call values win over the ones captured by the constructor.
    preprocess_params = @preprocess_params.merge(preprocess_params)
    forward_params = @forward_params.merge(forward_params)
    postprocess_params = @postprocess_params.merge(postprocess_params)

    @call_count += 1
    # @device may be nil or a value without #type, so check before calling it.
    on_cuda = @device.respond_to?(:type) && @device.type == "cuda"
    if @call_count > 10 && @framework == "pt" && on_cuda
      Transformers.logger.warn(
        "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a" +
        " dataset"
      )
    end

    return run_single(inputs, preprocess_params, forward_params, postprocess_params) unless inputs.is_a?(Array)

    # TODO make the get_iterator work also for `tf` (and `flax`).
    if @framework == "pt"
      get_iterator(
        inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
      ).to_a
    else
      run_multi(inputs, preprocess_params, forward_params, postprocess_params)
    end
  end

  private

  # Splits keyword arguments into [preprocess, forward, postprocess] hashes.
  # Must be provided by subclasses.
  def _sanitize_parameters(**kwargs)
    raise NotImplementedError, "_sanitize_parameters not implemented"
  end

  # Thin wrapper around the subclass's +_forward+.
  def forward(model_inputs, **forward_params)
    _forward(model_inputs, **forward_params)
  end

  # Runs the three steps on a single input.
  def run_single(inputs, preprocess_params, forward_params, postprocess_params)
    model_inputs = preprocess(inputs, **preprocess_params)
    model_outputs = forward(model_inputs, **forward_params)
    postprocess(model_outputs, **postprocess_params)
  end

  # Runs +run_single+ on every element, one after the other.
  def run_multi(inputs, preprocess_params, forward_params, postprocess_params)
    inputs.map { |item| run_single(item, preprocess_params, forward_params, postprocess_params) }
  end

  # Collate function for batch_size 1: unwraps the single item.
  def no_collate_fn(items)
    if items.length != 1
      raise ArgumentError, "This collate_fn is meant to be used with batch_size=1"
    end
    items[0]
  end
end
265
+
266
# Pipeline variant whose +preprocess+ yields several model inputs for one
# input; every yielded item is forwarded and the collected outputs are
# postprocessed together.
class ChunkPipeline < Pipeline
  def run_single(inputs, preprocess_params, forward_params, postprocess_params)
    collected = []
    preprocess(inputs, **preprocess_params) { |chunk| collected << forward(chunk, **forward_params) }
    postprocess(collected, **postprocess_params)
  end
end
277
+
278
# Resolves task names (and their aliases) to task definitions.
class PipelineRegistry
  # @param supported_tasks [Hash] task name => task definition
  # @param task_aliases [Hash] alias => task name
  def initialize(supported_tasks:, task_aliases:)
    @supported_tasks = supported_tasks
    @task_aliases = task_aliases
  end

  # @return [Array] every task name and alias, sorted
  def get_supported_tasks
    [*@supported_tasks.keys, *@task_aliases.keys].sort
  end

  # Resolves +task+ through the alias table and looks up its definition.
  #
  # @param task [String] task name or alias
  # @return [Array] +[resolved name, task definition, nil]+
  # @raise [KeyError] when the resolved name is not a supported task
  def check_task(task)
    resolved = @task_aliases[task] || task
    definition = @supported_tasks[resolved]
    return [resolved, definition, nil] if definition

    raise KeyError, "Unknown task #{resolved}, available tasks are #{get_supported_tasks}"
  end
end
301
+ end
@@ -0,0 +1,47 @@
1
+ module Transformers
2
# Pipeline that tokenizes its input, runs the model, and returns the first
# model output.
class FeatureExtractionPipeline < Pipeline
  # Splits the keyword arguments into the three per-step parameter hashes.
  #
  # @param truncation [Object, nil] when given, stored under +:truncation+ in
  #   the tokenizer arguments
  # @param tokenize_kwargs [Hash, nil] extra tokenizer arguments (not modified)
  # @param return_tensors [Object, nil] forwarded to +postprocess+ when given
  # @return [Array<Hash>] +[preprocess_params, forward_params, postprocess_params]+
  # @raise [ArgumentError] when truncation is given both directly and inside
  #   +tokenize_kwargs+
  def _sanitize_parameters(truncation: nil, tokenize_kwargs: nil, return_tensors: nil, **kwargs)
    # Copy before writing: adding :truncation to the caller's own hash would
    # make the next call with that same hash fail the "defined twice" check.
    preprocess_params = tokenize_kwargs.nil? ? {} : tokenize_kwargs.dup

    if !truncation.nil?
      if preprocess_params.include?(:truncation)
        raise ArgumentError,
          "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
      end
      preprocess_params[:truncation] = truncation
    end

    postprocess_params = {}
    if !return_tensors.nil?
      postprocess_params[:return_tensors] = return_tensors
    end

    [preprocess_params, {}, postprocess_params]
  end

  # Tokenizes +inputs+, asking the tokenizer for tensors of the pipeline's framework.
  def preprocess(inputs, **tokenize_kwargs)
    @tokenizer.(inputs, return_tensors: @framework, **tokenize_kwargs)
  end

  # Calls the model with the tokenizer output expanded as keyword arguments.
  def _forward(model_inputs)
    @model.(**model_inputs)
  end

  # Returns the first model output: as is when +return_tensors+ is truthy,
  # otherwise converted with +to_a+ for the "pt" framework. The "tf"
  # framework raises +Todo+; any other framework yields nil.
  def postprocess(model_outputs, return_tensors: false)
    # [0] is the first available tensor, logits or last_hidden_state.
    if return_tensors
      model_outputs[0]
    elsif @framework == "pt"
      model_outputs[0].to_a
    elsif @framework == "tf"
      raise Todo
    end
  end
end
47
+ end