transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
@@ -0,0 +1,348 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   TASK_ALIASES = {
+     "sentiment-analysis" => "text-classification",
+     "ner" => "token-classification"
+   }
+
+   SUPPORTED_TASKS = {
+     "feature-extraction" => {
+       "impl" => FeatureExtractionPipeline,
+       "pt" => [AutoModel],
+       "default" => {
+         "model" => {
+           "pt" => ["distilbert/distilbert-base-cased", "935ac13"]
+         }
+       },
+       "type" => "multimodal"
+     },
+     "text-classification" => {
+       "impl" => TextClassificationPipeline,
+       "pt" => [AutoModelForSequenceClassification],
+       "default" => {
+         "model" => {
+           "pt" => ["distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"]
+         }
+       },
+       "type" => "text"
+     },
+     "token-classification" => {
+       "impl" => TokenClassificationPipeline,
+       "pt" => [AutoModelForTokenClassification],
+       "default" => {
+         "model" => {
+           "pt" => ["dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"]
+         }
+       },
+       "type" => "text"
+     },
+     "question-answering" => {
+       "impl" => QuestionAnsweringPipeline,
+       "pt" => [AutoModelForQuestionAnswering],
+       "default" => {
+         "model" => {
+           "pt" => ["distilbert/distilbert-base-cased-distilled-squad", "626af31"]
+         }
+       },
+       "type" => "text"
+     },
+     "image-classification" => {
+       "impl" => ImageClassificationPipeline,
+       "pt" => [AutoModelForImageClassification],
+       "default" => {
+         "model" => {
+           "pt" => ["google/vit-base-patch16-224", "5dca96d"]
+         }
+       },
+       "type" => "image"
+     },
+     "image-feature-extraction" => {
+       "impl" => ImageFeatureExtractionPipeline,
+       "pt" => [AutoModel],
+       "default" => {
+         "model" => {
+           "pt" => ["google/vit-base-patch16-224", "3f49326"]
+         }
+       },
+       "type" => "image"
+     }
+   }
+
+   PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
+
+   class << self
+     def pipeline(
+       task,
+       model: nil,
+       config: nil,
+       tokenizer: nil,
+       feature_extractor: nil,
+       image_processor: nil,
+       framework: nil,
+       revision: nil,
+       use_fast: true,
+       token: nil,
+       device: nil,
+       device_map: nil,
+       torch_dtype: nil,
+       trust_remote_code: nil,
+       model_kwargs: nil,
+       pipeline_class: nil,
+       **kwargs
+     )
+       model_kwargs ||= {}
+       # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
+       # this is to keep BC).
+       use_auth_token = model_kwargs.delete(:use_auth_token)
+       if !use_auth_token.nil?
+         raise Todo
+       end
+
+       code_revision = kwargs.delete(:code_revision)
+       commit_hash = kwargs.delete(:_commit_hash)
+
+       hub_kwargs = {
+         revision: revision,
+         token: token,
+         trust_remote_code: trust_remote_code,
+         _commit_hash: commit_hash
+       }
+
+       if task.nil? && model.nil?
+         raise RuntimeError,
+           "Impossible to instantiate a pipeline without either a task or a model " +
+           "being specified. " +
+           "Please provide a task class or a model"
+       end
+
+       if model.nil? && !tokenizer.nil?
+         raise RuntimeError,
+           "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" +
+           " may not be compatible with the default model. Please provide a PreTrainedModel class or a" +
+           " path/identifier to a pretrained model when providing tokenizer."
+       end
+       if model.nil? && !feature_extractor.nil?
+         raise RuntimeError,
+           "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided" +
+           " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" +
+           " or a path/identifier to a pretrained model when providing feature_extractor."
+       end
+       if model.is_a?(Pathname)
+         model = model.to_s
+       end
+
+       if commit_hash.nil?
+         pretrained_model_name_or_path = nil
+         if config.is_a?(String)
+           pretrained_model_name_or_path = config
+         elsif config.nil? && model.is_a?(String)
+           pretrained_model_name_or_path = model
+         end
+
+         if !config.is_a?(PretrainedConfig) && !pretrained_model_name_or_path.nil?
+           # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+           resolved_config_file = Utils::Hub.cached_file(
+             pretrained_model_name_or_path,
+             CONFIG_NAME,
+             _raise_exceptions_for_gated_repo: false,
+             _raise_exceptions_for_missing_entries: false,
+             _raise_exceptions_for_connection_errors: false,
+             cache_dir: model_kwargs[:cache_dir],
+             **hub_kwargs
+           )
+           hub_kwargs[:_commit_hash] = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+         else
+           hub_kwargs[:_commit_hash] = nil # getattr(config, "_commit_hash", None)
+         end
+       end
+
+       # Config is the primordial information item.
+       # Instantiate config if needed
+       if config.is_a?(String)
+         raise Todo
+       elsif config.nil? && model.is_a?(String)
+         config = AutoConfig.from_pretrained(
+           model, _from_pipeline: task, code_revision: code_revision, **hub_kwargs, **model_kwargs
+         )
+         hub_kwargs[:_commit_hash] = config._commit_hash
+       end
+
+       custom_tasks = {}
+       if !config.nil? && (config.instance_variable_get(:@custom_pipelines) || {}).length > 0
+         raise Todo
+       end
+
+       if task.nil? && !model.nil?
+         raise Todo
+       end
+
+       # Retrieve the task
+       if custom_tasks.include?(task)
+         raise Todo
+       else
+         _normalized_task, targeted_task, task_options = check_task(task)
+         if pipeline_class.nil?
+           pipeline_class = targeted_task["impl"]
+         end
+       end
+
+       # Use default model/config/tokenizer for the task if no model is provided
+       if model.nil?
+         # At that point framework might still be undetermined
+         model, default_revision = Pipelines.get_default_model_and_revision(targeted_task, framework, task_options)
+         revision = !revision.nil? ? revision : default_revision
+         Transformers.logger.warn(
+           "No model was supplied, defaulted to #{model} and revision" +
+           " #{revision} (#{Utils::Hub::HUGGINGFACE_CO_RESOLVE_ENDPOINT}/#{model}).\n" +
+           "Using a pipeline without specifying a model name and revision in production is not recommended."
+         )
+         if config.nil? && model.is_a?(String)
+           config = AutoConfig.from_pretrained(model, _from_pipeline: task, **hub_kwargs, **model_kwargs)
+           hub_kwargs[:_commit_hash] = config._commit_hash
+         end
+       end
+
+       if !device_map.nil?
+         raise Todo
+       end
+       if !torch_dtype.nil?
+         raise Todo
+       end
+
+       model_name = model.is_a?(String) ? model : nil
+
+       # Load the correct model if possible
+       # Infer the framework from the model if not already defined
+       if model.is_a?(String) || framework.nil?
+         model_classes = {"tf" => targeted_task["tf"], "pt" => targeted_task["pt"]}
+         framework, model =
+           Pipelines.infer_framework_load_model(
+             model,
+             config,
+             model_classes: model_classes,
+             framework: framework,
+             task: task,
+             **hub_kwargs,
+             **model_kwargs
+           )
+       end
+
+       model_config = model.config
+       hub_kwargs[:_commit_hash] = model.config._commit_hash
+       model_config_type = model_config.class.name.split("::").last
+       load_tokenizer = TOKENIZER_MAPPING.include?(model_config_type) || !model_config.tokenizer_class.nil?
+       load_feature_extractor = FEATURE_EXTRACTOR_MAPPING.include?(model_config_type) || !feature_extractor.nil?
+       load_image_processor = IMAGE_PROCESSOR_MAPPING.include?(model_config_type) || !image_processor.nil?
+
+       if load_tokenizer
+         # Try to infer tokenizer from model or config name (if provided as str)
+         if tokenizer.nil?
+           if model_name.is_a?(String)
+             tokenizer = model_name
+           elsif config.is_a?(String)
+             tokenizer = config
+           else
+             # Impossible to guess what is the right tokenizer here
+             raise "Impossible to guess which tokenizer to use. Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer."
+           end
+         end
+
+         # Instantiate tokenizer if needed
+         if tokenizer.is_a?(String) || tokenizer.is_a?(Array)
+           if tokenizer.is_a?(Array)
+             # For array we have [tokenizer name, {kwargs}]
+             use_fast = tokenizer[1].delete(:use_fast) { use_fast }
+             tokenizer_identifier = tokenizer[0]
+             tokenizer_kwargs = tokenizer[1]
+           else
+             tokenizer_identifier = tokenizer
+             tokenizer_kwargs = model_kwargs.dup
+             tokenizer_kwargs.delete(:torch_dtype)
+           end
+
+           tokenizer =
+             AutoTokenizer.from_pretrained(
+               tokenizer_identifier, use_fast: use_fast, _from_pipeline: task, **hub_kwargs, **tokenizer_kwargs
+             )
+         end
+       end
+
+       if load_image_processor
+         # Try to infer image processor from model or config name (if provided as str)
+         if image_processor.nil?
+           if model_name.is_a?(String)
+             image_processor = model_name
+           elsif config.is_a?(String)
+             image_processor = config
+           # Backward compatibility, as `feature_extractor` used to be the name
+           # for `ImageProcessor`.
+           elsif !feature_extractor.nil? && feature_extractor.is_a?(BaseImageProcessor)
+             image_processor = feature_extractor
+           else
+             # Impossible to guess what is the right image_processor here
+             raise RuntimeError,
+               "Impossible to guess which image processor to use. " +
+               "Please provide a PreTrainedImageProcessor class or a path/identifier " +
+               "to a pretrained image processor."
+           end
+         end
+
+         # Instantiate image_processor if needed
+         if image_processor.is_a?(String) || image_processor.is_a?(Array)
+           image_processor = AutoImageProcessor.from_pretrained(
+             image_processor, _from_pipeline: task, **hub_kwargs, **model_kwargs
+           )
+         end
+       end
+
+       if load_feature_extractor
+         raise Todo
+       end
+
+       if task == "translation" && model.config.task_specific_params
+         raise Todo
+       end
+
+       if !tokenizer.nil?
+         kwargs[:tokenizer] = tokenizer
+       end
+
+       if !feature_extractor.nil?
+         kwargs[:feature_extractor] = feature_extractor
+       end
+
+       if !torch_dtype.nil?
+         kwargs[:torch_dtype] = torch_dtype
+       end
+
+       if !image_processor.nil?
+         kwargs[:image_processor] = image_processor
+       end
+
+       if !device.nil?
+         kwargs[:device] = device
+       end
+
+       pipeline_class.new(model, framework: framework, task: task, **kwargs)
+     end
+
+     private
+
+     def check_task(task)
+       PIPELINE_REGISTRY.check_task(task)
+     end
+   end
+ end
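
For orientation (an editor's sketch, not part of the diff): with the registry above in place, a task name or alias is all that is needed to build a pipeline, and the pinned default model and revision are used when none is given. A minimal usage sketch, assuming the default models can be downloaded:

```ruby
require "transformers-rb"

# "sentiment-analysis" is mapped to "text-classification" via TASK_ALIASES and
# defaults to distilbert-base-uncased-finetuned-sst-2-english at revision af0f99b.
classifier = Transformers.pipeline("sentiment-analysis")
classifier.("This gem works great!") # => e.g. {label: "POSITIVE", score: 0.99...}

# Supplying an explicit model and revision avoids the "No model was supplied" warning.
qa = Transformers.pipeline(
  "question-answering",
  model: "distilbert/distilbert-base-cased-distilled-squad",
  revision: "626af31"
)
qa.(question: "Which framework is supported?", context: "Only the pt framework is wired up in this version.")
```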
@@ -0,0 +1,301 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Pipelines
+     def self.get_default_model_and_revision(targeted_task, framework, task_options)
+       defaults = targeted_task["default"]
+
+       if defaults.key?("model")
+         default_models = targeted_task["default"]["model"]
+       end
+
+       if framework.nil?
+         framework = "pt"
+       end
+
+       default_models[framework]
+     end
+
+     def self.infer_framework_load_model(
+       model,
+       config,
+       model_classes: nil,
+       task: nil,
+       framework: nil,
+       **model_kwargs
+     )
+       if model.is_a?(String)
+         model_kwargs[:_from_pipeline] = task
+         class_tuple = []
+         look_pt = true
+
+         if model_classes
+           if look_pt
+             # The fallback must be an array so it can be concatenated below.
+             class_tuple = class_tuple + model_classes.fetch("pt", [AutoModel])
+           end
+         end
+         if config.architectures
+           classes = []
+           config.architectures.each do |architecture|
+             if look_pt
+               # const_get raises NameError for unknown architectures, so guard with const_defined?
+               _class = Transformers.const_defined?(architecture) ? Transformers.const_get(architecture) : nil
+               if !_class.nil?
+                 classes << _class
+               end
+             end
+           end
+           class_tuple = class_tuple + classes
+         end
+
+         if class_tuple.length == 0
+           raise ArgumentError, "Pipeline cannot infer suitable model classes from #{model}"
+         end
+
+         class_tuple.each do |model_class|
+           raise Error, "Invalid auto model class: #{model_class}" unless model_class < BaseAutoModelClass
+           kwargs = model_kwargs.dup
+
+           begin
+             model = model_class.from_pretrained(model, **kwargs)
+             if model.respond_to?(:eval)
+               model = model.eval
+             end
+             break
+           rescue
+             # TODO
+             raise
+           end
+         end
+       end
+
+       if framework.nil?
+         framework = Utils.infer_framework(model.class)
+       end
+       [framework, model]
+     end
+   end
+
+   class ArgumentHandler
+   end
+
+   class Pipeline
+     def initialize(
+       model,
+       tokenizer: nil,
+       feature_extractor: nil,
+       image_processor: nil,
+       modelcard: nil,
+       framework: nil,
+       task: "",
+       device: nil,
+       **kwargs
+     )
+       if framework.nil?
+         raise Todo
+       end
+
+       @task = task
+       @model = model
+       @tokenizer = tokenizer
+       @feature_extractor = feature_extractor
+       @image_processor = image_processor
+       @modelcard = modelcard
+       @framework = framework
+       # Store the device so call can check it later.
+       @device = device
+
+       if device.nil?
+         if Torch::CUDA.available? || Torch::Backends::MPS.available?
+           Transformers.logger.warn(
+             "A hardware accelerator (e.g. a GPU) is available in the environment, but no `device` argument" +
+             " is passed to the `Pipeline` object. The model will run on CPU."
+           )
+         end
+       end
+
+       @call_count = 0
+       @batch_size = kwargs.delete(:batch_size)
+       @num_workers = kwargs.delete(:num_workers)
+       @preprocess_params, @forward_params, @postprocess_params = _sanitize_parameters(**kwargs)
+     end
+
+     def torch_dtype
+       @model.dtype
+     end
+
+     def check_model_type(supported_models)
+       if !supported_models.is_a?(Array)
+         supported_models_names = []
+         supported_models.each do |_, model_name|
+           # Mapping can now contain tuples of models for the same configuration.
+           if model_name.is_a?(Array)
+             supported_models_names.concat(model_name)
+           else
+             supported_models_names << model_name
+           end
+         end
+         supported_models = supported_models_names
+       end
+       if !supported_models.include?(@model.class.name.split("::").last)
+         Transformers.logger.error(
+           "The model '#{@model.class.name}' is not supported for #{@task}. Supported models are" +
+           " #{supported_models}."
+         )
+       end
+     end
+
+     def get_iterator(
+       inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
+     )
+       if inputs.respond_to?(:size)
+         dataset = PipelineDataset.new(inputs, method(:preprocess), preprocess_params)
+       else
+         if num_workers > 1
+           Transformers.logger.warn(
+             "For iterable datasets, using num_workers > 1 is likely to result" +
+             " in errors since everything is iterable; setting `num_workers: 1`" +
+             " to guarantee correctness."
+           )
+           num_workers = 1
+         end
+         dataset = PipelineIterator.new(inputs, method(:preprocess), preprocess_params)
+       end
+
+       # TODO hack by collating feature_extractor and image_processor
+       feature_extractor = !@feature_extractor.nil? ? @feature_extractor : @image_processor
+       collate_fn = batch_size == 1 ? method(:no_collate_fn) : pad_collate_fn(@tokenizer, feature_extractor)
+       dataloader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: batch_size, collate_fn: collate_fn) # num_workers: num_workers,
+       model_iterator = PipelineIterator.new(dataloader, method(:forward), forward_params, loader_batch_size: batch_size)
+       final_iterator = PipelineIterator.new(model_iterator, method(:postprocess), postprocess_params)
+       final_iterator
+     end
+
+     def call(inputs, *args, num_workers: nil, batch_size: nil, **kwargs)
+       if args.any?
+         Transformers.logger.warn("Ignoring args: #{args}")
+       end
+
+       if num_workers.nil?
+         if @num_workers.nil?
+           num_workers = 0
+         else
+           num_workers = @num_workers
+         end
+       end
+       if batch_size.nil?
+         if @batch_size.nil?
+           batch_size = 1
+         else
+           batch_size = @batch_size
+         end
+       end
+
+       preprocess_params, forward_params, postprocess_params = _sanitize_parameters(**kwargs)
+
+       preprocess_params = @preprocess_params.merge(preprocess_params)
+       forward_params = @forward_params.merge(forward_params)
+       postprocess_params = @postprocess_params.merge(postprocess_params)
+
+       @call_count += 1
+       if @call_count > 10 && @framework == "pt" && !@device.nil? && @device.type == "cuda"
+         Transformers.logger.warn(
+           "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a" +
+           " dataset"
+         )
+       end
+
+       is_dataset = inputs.is_a?(Torch::Utils::Data::Dataset)
+       is_generator = inputs.is_a?(Enumerable)
+       is_list = inputs.is_a?(Array)
+
+       _is_iterable = is_dataset || is_generator || is_list
+
+       # TODO make the get_iterator work also for `tf` (and `flax`).
+       can_use_iterator = @framework == "pt" && (is_dataset || is_generator || is_list)
+
+       if is_list
+         if can_use_iterator
+           final_iterator = get_iterator(
+             inputs, num_workers, batch_size, preprocess_params, forward_params, postprocess_params
+           )
+           outputs = final_iterator.to_a
+           outputs
+         else
+           run_multi(inputs, preprocess_params, forward_params, postprocess_params)
+         end
+       else
+         run_single(inputs, preprocess_params, forward_params, postprocess_params)
+       end
+     end
+
+     private
+
+     def _sanitize_parameters(**kwargs)
+       raise NotImplementedError, "_sanitize_parameters not implemented"
+     end
+
+     def forward(model_inputs, **forward_params)
+       _forward(model_inputs, **forward_params)
+     end
+
+     # Referenced by call above when the iterator path is unavailable:
+     # process each input in turn through run_single.
+     def run_multi(inputs, preprocess_params, forward_params, postprocess_params)
+       inputs.map { |item| run_single(item, preprocess_params, forward_params, postprocess_params) }
+     end
+
+     def run_single(inputs, preprocess_params, forward_params, postprocess_params)
+       model_inputs = preprocess(inputs, **preprocess_params)
+       model_outputs = forward(model_inputs, **forward_params)
+       outputs = postprocess(model_outputs, **postprocess_params)
+       outputs
+     end
+
+     def no_collate_fn(items)
+       if items.length != 1
+         raise ArgumentError, "This collate_fn is meant to be used with batch_size=1"
+       end
+       items[0]
+     end
+   end
+
+   class ChunkPipeline < Pipeline
+     def run_single(inputs, preprocess_params, forward_params, postprocess_params)
+       all_outputs = []
+       preprocess(inputs, **preprocess_params) do |model_inputs|
+         model_outputs = forward(model_inputs, **forward_params)
+         all_outputs << model_outputs
+       end
+       outputs = postprocess(all_outputs, **postprocess_params)
+       outputs
+     end
+   end
+
+   class PipelineRegistry
+     def initialize(supported_tasks:, task_aliases:)
+       @supported_tasks = supported_tasks
+       @task_aliases = task_aliases
+     end
+
+     def get_supported_tasks
+       supported_task = @supported_tasks.keys + @task_aliases.keys
+       supported_task.sort
+     end
+
+     def check_task(task)
+       if @task_aliases[task]
+         task = @task_aliases[task]
+       end
+       if @supported_tasks[task]
+         targeted_task = @supported_tasks[task]
+         return task, targeted_task, nil
+       end
+
+       raise KeyError, "Unknown task #{task}, available tasks are #{get_supported_tasks}"
+     end
+   end
+ end
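
A quick sketch of task resolution as implemented by `PipelineRegistry#check_task` above (return values read straight from the code; not part of the diff):

```ruby
# Aliases are resolved through TASK_ALIASES first, then looked up in SUPPORTED_TASKS.
task, targeted_task, options = Transformers::PIPELINE_REGISTRY.check_task("ner")
task                  # => "token-classification"
targeted_task["impl"] # => Transformers::TokenClassificationPipeline
options               # => nil (no task-specific options for these tasks)

# Unknown tasks raise, listing everything from get_supported_tasks.
Transformers::PIPELINE_REGISTRY.check_task("translation")
# => KeyError: Unknown task translation, available tasks are [...]
```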
@@ -0,0 +1,47 @@
+ module Transformers
+   class FeatureExtractionPipeline < Pipeline
+     def _sanitize_parameters(truncation: nil, tokenize_kwargs: nil, return_tensors: nil, **kwargs)
+       if tokenize_kwargs.nil?
+         tokenize_kwargs = {}
+       end
+
+       if !truncation.nil?
+         if tokenize_kwargs.include?(:truncation)
+           raise ArgumentError,
+             "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
+         end
+         tokenize_kwargs[:truncation] = truncation
+       end
+
+       preprocess_params = tokenize_kwargs
+
+       postprocess_params = {}
+       if !return_tensors.nil?
+         postprocess_params[:return_tensors] = return_tensors
+       end
+
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def preprocess(inputs, **tokenize_kwargs)
+       model_inputs = @tokenizer.(inputs, return_tensors: @framework, **tokenize_kwargs)
+       model_inputs
+     end
+
+     def _forward(model_inputs)
+       model_outputs = @model.(**model_inputs)
+       model_outputs
+     end
+
+     def postprocess(model_outputs, return_tensors: false)
+       # [0] is the first available tensor, logits or last_hidden_state.
+       if return_tensors
+         model_outputs[0]
+       elsif @framework == "pt"
+         model_outputs[0].to_a
+       elsif @framework == "tf"
+         raise Todo
+       end
+     end
+   end
+ end
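
Putting the pieces together, a hedged usage sketch for this pipeline (the default model comes from the registry in `pipelines/_init.rb`; the hidden size of 768 is a property of that DistilBERT checkpoint):

```ruby
extractor = Transformers.pipeline("feature-extraction") # distilbert/distilbert-base-cased by default

# For the "pt" framework, postprocess converts the hidden-state tensor to nested
# Ruby arrays shaped [batch][tokens][hidden], so features[0] is the token embeddings.
features = extractor.("Ruby is fun")
features[0][0].length # => 768

# Passing return_tensors: true skips the Array conversion and yields the raw Torch tensor.
tensor = extractor.("Ruby is fun", return_tensors: true)
```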