transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
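The remainder of this diff shows the largest new file, data/lib/transformers/modeling_utils.rb, which ports the Hugging Face PreTrainedModel checkpoint-loading machinery to Ruby on top of Torch.rb. For orientation, here is a minimal usage sketch; it assumes the pipeline API suggested by the pipelines/* files listed above, and the task name, input text, and output shape are illustrative rather than taken from this diff:

    require "transformers-rb"

    # High-level entry point: resolves the config, tokenizer, and weights for a
    # default checkpoint, then runs inference. The weight download and
    # state-dict loading happen in PreTrainedModel.from_pretrained, defined in
    # modeling_utils.rb below.
    sentiment = Transformers.pipeline("sentiment-analysis")
    sentiment.("Ruby bindings for Transformers are handy")
    # => a hash with a label and a score (exact shape depends on the pipeline)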
data/lib/transformers/modeling_utils.rb
@@ -0,0 +1,888 @@
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module ModuleUtilsMixin
+     def get_extended_attention_mask(
+       attention_mask,
+       input_shape,
+       device: nil,
+       dtype: nil
+     )
+       if dtype.nil?
+         dtype = @dtype
+       end
+
+       if !(attention_mask.dim == 2 && @config.is_decoder)
+         # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+         if !device.nil?
+           raise Todo
+         end
+       end
+       # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+       # ourselves in which case we just need to make it broadcastable to all heads.
+       if attention_mask.dim == 3
+         raise Todo
+       elsif attention_mask.dim == 2
+         # Provided a padding mask of dimensions [batch_size, seq_length]
+         # - if the model is a decoder, apply a causal mask in addition to the padding mask
+         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if @config.is_decoder
+           raise Todo
+         else
+           extended_attention_mask = attention_mask[0.., nil, nil, 0..]
+         end
+       else
+         raise Todo
+       end
+
+       # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+       # masked positions, this operation will create a tensor which is 0.0 for
+       # positions we want to attend and the dtype's smallest value for masked positions.
+       # Since we are adding it to the raw scores before the softmax, this is
+       # effectively the same as removing these entirely.
+       extended_attention_mask = extended_attention_mask.to(dtype: dtype) # fp16 compatibility
+       # TODO use Torch.finfo
+       extended_attention_mask = (1.0 - extended_attention_mask) * -3.40282e+38
+       extended_attention_mask
+     end
+
+     def get_head_mask(head_mask, num_hidden_layers, is_attention_chunked: false)
+       if !head_mask.nil?
+         head_mask = _convert_head_mask_to_5d(head_mask, num_hidden_layers)
+         if is_attention_chunked == true
+           head_mask = head_mask.unsqueeze(-1)
+         end
+       else
+         head_mask = [nil] * num_hidden_layers
+       end
+
+       head_mask
+     end
+   end
+
+   class PreTrainedModel < Torch::NN::Module
+     extend ClassAttribute
+     include ModuleUtilsMixin
+
+     class_attribute :config_class
+     class_attribute :base_model_prefix, ""
+     class_attribute :main_input_name, "input_ids"
+     class_attribute :model_tags
+
+     class_attribute :_tied_weights_keys
+
+     attr_reader :config
+
+     def dummy_inputs
+       raise Todo
+     end
+
+     def framework
+       "pt"
+     end
+
+     def initialize(config, *inputs, **kwargs)
+       super()
+       @config = config
+     end
+
+     def post_init
+       init_weights
+       _backward_compatibility_gradient_checkpointing
+     end
+
+     def dequantize
+       raise Todo
+     end
+
+     def _backward_compatibility_gradient_checkpointing
+       # TODO
+     end
+
+     def base_model
+       instance_variable_get("@#{self.class.base_model_prefix}") || self
+     end
+
+     def can_generate
+       # TODO improve
+       false
+     end
+
+     def get_input_embeddings
+       raise Todo
+     end
+
+     def set_input_embeddings(value)
+       raise Todo
+     end
+
+     def get_output_embeddings
+       nil # Overwrite for models with output embeddings
+     end
+
+     def _init_weights(mod)
+       # pass
+     end
+
+     def _initialize_weights(mod)
+       _init_weights(mod)
+     end
+
+     def tie_weights
+       if @config.tie_word_embeddings != false
+         output_embeddings = get_output_embeddings
+         if !output_embeddings.nil?
+           raise Todo
+         end
+       end
+
+       if @config.is_encoder_decoder && @config.tie_encoder_decoder
+         raise Todo
+       end
+
+       modules.each do |mod|
+         if mod.respond_to?(:_tie_weights)
+           mod._tie_weights
+         end
+       end
+     end
+
+     def init_weights
+       # Prune heads if needed
+       if @config.pruned_heads
+         prune_heads(@config.pruned_heads)
+       end
+
+       if true
+         # Initialize weights
+         apply(method(:_initialize_weights))
+
+         # Tie weights should be skipped when not initializing all weights
+         # since from_pretrained(...) calls tie weights anyways
+         tie_weights
+       end
+     end
+
+     def prune_heads(heads_to_prune)
+       # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
+       heads_to_prune.each do |layer, heads|
+         union_heads = Set.new(@config.pruned_heads.fetch(layer, [])) | Set.new(heads)
+         @config.pruned_heads[layer] = union_heads.to_a # Unfortunately we have to store it as list for JSON
+       end
+
+       base_model._prune_heads(heads_to_prune)
+     end
+
+     def warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+       if !attention_mask.nil? || @config.pad_token_id.nil?
+         return
+       end
+
+       # Check only the first and last input IDs to reduce overhead.
+       if input_ids[0.., [-1, 0]].include?(@config.pad_token_id)
+         raise Todo
+       end
+     end
+
+     class << self
+       def from_pretrained(
+         pretrained_model_name_or_path,
+         *model_args,
+         config: nil,
+         cache_dir: nil,
+         ignore_mismatched_sizes: false,
+         force_download: false,
+         local_files_only: false,
+         token: nil,
+         revision: "main",
+         use_safetensors: nil,
+         **kwargs
+       )
+         state_dict = kwargs.delete(:state_dict)
+         from_tf = kwargs.delete(:from_tf) { false }
+         from_flax = kwargs.delete(:from_flax) { false }
+         resume_download = kwargs.delete(:resume_download) { false }
+         proxies = kwargs.delete(:proxies)
+         output_loading_info = kwargs.delete(:output_loading_info) { false }
+         _use_auth_token = kwargs.delete(:use_auth_token)
+         trust_remote_code = kwargs.delete(:trust_remote_code)
+         _ = kwargs.delete(:mirror)
+         from_pipeline = kwargs.delete(:_from_pipeline)
+         from_auto_class = kwargs.delete(:_from_auto) { false }
+         _fast_init = kwargs.delete(:_fast_init) { true }
+         torch_dtype = kwargs.delete(:torch_dtype)
+         low_cpu_mem_usage = kwargs.delete(:low_cpu_mem_usage)
+         device_map = kwargs.delete(:device_map)
+         _max_memory = kwargs.delete(:max_memory)
+         offload_folder = kwargs.delete(:offload_folder)
+         offload_state_dict = kwargs.delete(:offload_state_dict) { false }
+         load_in_8bit = kwargs.delete(:load_in_8bit) { false }
+         load_in_4bit = kwargs.delete(:load_in_4bit) { false }
+         quantization_config = kwargs.delete(:quantization_config)
+         subfolder = kwargs.delete(:subfolder) { "" }
+         commit_hash = kwargs.delete(:_commit_hash)
+         variant = kwargs.delete(:variant)
+         _adapter_kwargs = kwargs.delete(:adapter_kwargs) { {} }
+         _adapter_name = kwargs.delete(:adapter_name) { "default" }
+         _use_flash_attention_2 = kwargs.delete(:use_flash_attention_2) { false }
+
+         if use_safetensors.nil? && !is_safetensors_available
+           use_safetensors = false
+         end
+         if trust_remote_code
+           Transformers.logger.warn(
+             "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" +
+             " ignored."
+           )
+         end
+
+         if commit_hash.nil?
+           if !config.is_a?(PretrainedConfig)
+             # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+             resolved_config_file =
+               Utils::Hub.cached_file(
+                 pretrained_model_name_or_path,
+                 CONFIG_NAME,
+                 cache_dir: cache_dir,
+                 force_download: force_download,
+                 resume_download: resume_download,
+                 proxies: proxies,
+                 local_files_only: local_files_only,
+                 token: token,
+                 revision: revision,
+                 subfolder: subfolder,
+                 _raise_exceptions_for_gated_repo: false,
+                 _raise_exceptions_for_missing_entries: false,
+                 _raise_exceptions_for_connection_errors: false,
+               )
+             commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+           else
+             commit_hash = config._commit_hash
+           end
+         end
+
+         if !device_map.nil?
+           raise Todo
+         end
+
+         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
+         if load_in_4bit || load_in_8bit
+           raise Todo
+         end
+
+         from_pt = !(from_tf || from_flax)
+
+         user_agent = {file_type: "model", framework: "pytorch", from_auto_class: from_auto_class}
+         if !from_pipeline.nil?
+           user_agent[:using_pipeline] = from_pipeline
+         end
+
+         if Utils::Hub.is_offline_mode && !local_files_only
+           Transformers.logger.info "Offline mode: forcing local_files_only: true"
+           local_files_only = true
+         end
+
+         # Load config if we don't provide a configuration
+         if !config.is_a?(PretrainedConfig)
+           config_path = !config.nil? ? config : pretrained_model_name_or_path
+           config, model_kwargs =
+             config_class.from_pretrained(
+               config_path,
+               cache_dir: cache_dir,
+               return_unused_kwargs: true,
+               force_download: force_download,
+               resume_download: resume_download,
+               proxies: proxies,
+               local_files_only: local_files_only,
+               token: token,
+               revision: revision,
+               subfolder: subfolder,
+               _from_auto: from_auto_class,
+               _from_pipeline: from_pipeline,
+               **kwargs
+             )
+         else
+           # In case one passes a config to `from_pretrained` + "attn_implementation"
+           # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
+           # Please see: https://github.com/huggingface/transformers/issues/28038
+
+           # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
+           # we pop attn_implementation from the kwargs but this handles the case where users
+           # pass the config manually to `from_pretrained`.
+           config = Copy.deepcopy(config)
+
+           kwarg_attn_imp = kwargs.delete(:attn_implementation)
+           if !kwarg_attn_imp.nil? && config._attn_implementation != kwarg_attn_imp
+             config._attn_implementation = kwarg_attn_imp
+           end
+           model_kwargs = kwargs
+         end
+
+         pre_quantized = false # !config.quantization_config.nil?
+         if pre_quantized || !quantization_config.nil?
+           raise Todo
+         else
+           hf_quantizer = nil
+         end
+
+         if !hf_quantizer.nil?
+           raise Todo
+         end
+
+         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+         # index of the files.
+         is_sharded = false
+         sharded_metadata = nil
+         # Load model
+         _loading_info = nil
+
+         # Keep in fp32 modules
+         keep_in_fp32_modules = nil
+         _use_keep_in_fp32_modules = false
+
+         resolved_archive_file = nil
+         if !pretrained_model_name_or_path.nil?
+           pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
+           is_local = Dir.exist?(pretrained_model_name_or_path)
+           if is_local
+             raise Todo
+           elsif File.exist?(File.join(subfolder, pretrained_model_name_or_path))
+             archive_file = pretrained_model_name_or_path
+             is_local = true
+           else
+             # set correct filename
+             if use_safetensors != false
+               filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
+             else
+               filename = _add_variant(WEIGHTS_NAME, variant)
+             end
+
+             # Load from URL or cache if already cached
+             cached_file_kwargs = {
+               cache_dir: cache_dir,
+               force_download: force_download,
+               proxies: proxies,
+               resume_download: resume_download,
+               local_files_only: local_files_only,
+               token: token,
+               user_agent: user_agent,
+               revision: revision,
+               subfolder: subfolder,
+               _raise_exceptions_for_gated_repo: false,
+               _raise_exceptions_for_missing_entries: false,
+               _commit_hash: commit_hash
+             }
+             resolved_archive_file = Utils::Hub.cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+             # Since we set _raise_exceptions_for_missing_entries: false, we don't get an exception but a nil
+             # result when the internet is up, the repo and revision exist, but the file does not.
+             if resolved_archive_file.nil? && filename == _add_variant(SAFE_WEIGHTS_NAME, variant)
+               # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+               resolved_archive_file = Utils::Hub.cached_file(
+                 pretrained_model_name_or_path,
+                 _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
+                 **cached_file_kwargs,
+               )
+               if !resolved_archive_file.nil?
+                 is_sharded = true
+               elsif use_safetensors
+                 raise Todo
+               else
+                 # This repo has no safetensors file of any kind, we switch to PyTorch.
+                 filename = _add_variant(WEIGHTS_NAME, variant)
+                 resolved_archive_file = Utils::Hub.cached_file(
+                   pretrained_model_name_or_path, filename, **cached_file_kwargs
+                 )
+               end
+             end
+             if resolved_archive_file.nil? && filename == _add_variant(WEIGHTS_NAME, variant)
+               # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+               resolved_archive_file = Utils::Hub.cached_file(
+                 pretrained_model_name_or_path,
+                 _add_variant(WEIGHTS_INDEX_NAME, variant),
+                 **cached_file_kwargs
+               )
+               if !resolved_archive_file.nil?
+                 is_sharded = true
+               end
+             end
+             if !local_files_only && !Utils::Hub.is_offline_mode
+               if !resolved_archive_file.nil?
+                 if [WEIGHTS_NAME, WEIGHTS_INDEX_NAME].include?(filename)
+                   # If the PyTorch file was found, check if there is a safetensors file on the repository
+                   # If there is no safetensors file on the repositories, start an auto conversion
+                   _safe_weights_name = is_sharded ? SAFE_WEIGHTS_INDEX_NAME : SAFE_WEIGHTS_NAME
+                   has_file_kwargs = {
+                     revision: revision,
+                     proxies: proxies,
+                     token: token,
+                     cache_dir: cache_dir,
+                     local_files_only: local_files_only
+                   }
+                   cached_file_kwargs = {
+                     cache_dir: cache_dir,
+                     force_download: force_download,
+                     resume_download: resume_download,
+                     local_files_only: local_files_only,
+                     user_agent: user_agent,
+                     subfolder: subfolder,
+                     _raise_exceptions_for_gated_repo: false,
+                     _raise_exceptions_for_missing_entries: false,
+                     _commit_hash: commit_hash,
+                     **has_file_kwargs
+                   }
+                   # skip auto conversion
+                   # if !Utils::Hub.has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
+                   # end
+                 end
+               else
+                 raise Todo
+               end
+             end
+
+             if is_local
+               Transformers.logger.info("loading weights file #{archive_file}")
+               resolved_archive_file = archive_file
+             else
+               Transformers.logger.info("loading weights file #{filename} from cache at #{resolved_archive_file}")
+             end
+           end
+         else
+           resolved_archive_file = nil
+         end
+
+         # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+         if is_sharded
+           raise Todo
+         end
+
+         metadata = nil
+         if is_safetensors_available && resolved_archive_file.is_a?(String) && resolved_archive_file.end_with?(".safetensors")
+           Safetensors.safe_open(resolved_archive_file, framework: "pt") do |f|
+             metadata = f.metadata
+           end
+
+           if metadata["format"] == "pt"
+             # do nothing
+           else
+             raise ArgumentError,
+               "Incompatible safetensors file. File metadata is not 'pt' but #{metadata["format"]}"
+           end
+         end
+
+         from_pt = !(from_tf || from_flax)
+
+         # load pt weights early so that we know which dtype to init the model under
+         if from_pt
+           if !is_sharded && state_dict.nil?
+             # Time to load the checkpoint
+             state_dict = load_state_dict(resolved_archive_file)
+           end
+
+           # set dtype to instantiate the model under:
+           # 1. If torch_dtype is not nil, we use that dtype
+           # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+           #    weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+           #    we also may have config.torch_dtype available, but we won't rely on it till v5
+           dtype_orig = nil
+
+           if !torch_dtype.nil?
+             raise Todo
+           end
+
+           if is_sharded
+             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+           else
+             loaded_state_dict_keys = state_dict.keys
+           end
+         end
+
+         config.name_or_path = pretrained_model_name_or_path
+
+         model_kwargs = {}
+         model = new(config, *model_args, **model_kwargs)
+
+         # make sure we use the model's config since the initialize call might have copied it
+         config = model.config
+
+         if device_map.is_a?(String)
+           raise Todo
+         elsif !device_map.nil?
+           raise Todo
+         end
+
+         if from_pt
+           # restore default dtype
+           if !dtype_orig.nil?
+             Torch.set_default_dtype(dtype_orig)
+           end
+
+           model, _missing_keys, _unexpected_keys, _mismatched_keys, _offload_index, _error_msgs =
+             _load_pretrained_model(
+               model,
+               state_dict,
+               loaded_state_dict_keys, # XXX: rename?
+               resolved_archive_file,
+               pretrained_model_name_or_path,
+               ignore_mismatched_sizes: ignore_mismatched_sizes,
+               sharded_metadata: sharded_metadata,
+               _fast_init: _fast_init,
+               low_cpu_mem_usage: low_cpu_mem_usage,
+               device_map: device_map,
+               offload_folder: offload_folder,
+               offload_state_dict: offload_state_dict,
+               dtype: torch_dtype,
+               hf_quantizer: hf_quantizer,
+               keep_in_fp32_modules: keep_in_fp32_modules
+             )
+         end
+
+         # make sure token embedding weights are still tied if needed
+         model.tie_weights
+
+         # Set model in evaluation mode to deactivate DropOut modules by default
+         model.eval
+
+         # If it is a model with generation capabilities, attempt to load the generation config
+         if model.can_generate && !pretrained_model_name_or_path.nil?
+           raise Todo
+         end
+
+         # Dispatch model with hooks on all devices if necessary
+         if !device_map.nil?
+           raise Todo
+         end
+
+         if !hf_quantizer.nil?
+           raise Todo
+         end
+
+         if output_loading_info
+           raise Todo
+         end
+
+         model
+       end
+
+       private
+
+       def _load_pretrained_model(
+         model,
+         state_dict,
+         loaded_keys,
+         resolved_archive_file,
+         pretrained_model_name_or_path,
+         ignore_mismatched_sizes: false,
+         sharded_metadata: nil,
+         _fast_init: true,
+         low_cpu_mem_usage: false,
+         device_map: nil,
+         offload_folder: nil,
+         offload_state_dict: nil,
+         dtype: nil,
+         hf_quantizer: nil,
+         keep_in_fp32_modules: nil
+       )
+         is_safetensors = false
+
+         _is_sharded_safetensors = is_safetensors && !sharded_metadata.nil?
+
+         # tie the model weights before retrieving the state_dict
+         model.tie_weights
+
+         # Retrieve missing & unexpected_keys
+         model_state_dict = model.state_dict
+         expected_keys = model_state_dict.keys
+         prefix = model.class.base_model_prefix
+
+         _fix_key = lambda do |key|
+           if key.include?("beta")
+             key = key.gsub("beta", "bias")
+           end
+           if key.include?("gamma")
+             key = key.gsub("gamma", "weight")
+           end
+           key
+         end
+
+         original_loaded_keys = loaded_keys
+         loaded_keys = loaded_keys.map { |key| _fix_key.(key) }
+
+         if prefix.length > 0
+           has_prefix_module = loaded_keys.any? { |s| s.start_with?(prefix) }
+           expects_prefix_module = expected_keys.any? { |s| s.start_with?(prefix) }
+         else
+           has_prefix_module = false
+           expects_prefix_module = false
+         end
+
+         # key re-naming operations are never done on the keys
+         # that are loaded, but always on the keys of the newly initialized model
+         remove_prefix_from_model = !has_prefix_module && expects_prefix_module
+         add_prefix_to_model = has_prefix_module && !expects_prefix_module
+
+         if remove_prefix_from_model
+           _prefix = "#{prefix}."
+           expected_keys_not_prefixed = expected_keys.select { |s| !s.start_with?(_prefix) }
+           expected_keys = expected_keys.map { |s| s.start_with?(_prefix) ? s[_prefix.length..] : s }
+         elsif add_prefix_to_model
+           expected_keys = expected_keys.map { |s| [prefix, s].join(".") }
+         end
+
+         missing_keys = (Set.new(expected_keys) - Set.new(loaded_keys)).sort
+         unexpected_keys = Set.new(loaded_keys) - Set.new(expected_keys)
+         # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
+         # buffers
+         model_buffers = model.named_buffers(recurse: true).keys
+         if remove_prefix_from_model
+           raise Todo
+         elsif add_prefix_to_model
+           model_buffers = model_buffers.map { |key| [prefix, key].join(".") }
+         end
+         unexpected_keys = (unexpected_keys - model_buffers).sort
+
+         model.tie_weights
+         if device_map.nil?
+           ptrs = Hash.new { |hash, key| hash[key] = [] }
+
+           model.state_dict.each do |name, tensor|
+             # TODO fix
+             id_tensor = tensor.object_id # id_tensor_storage(tensor)
+             ptrs[id_tensor] << name
+           end
+
+           # These are all the pointers of shared tensors.
+           tied_params = ptrs.select { |_, names| names.length > 1 }.values
+         else
+           raise Todo
+         end
+
+         tied_params.each do |group|
+           if remove_prefix_from_model
+             group = group.map { |key| key.delete_prefix(_prefix) }
+           elsif add_prefix_to_model
+             group = group.map { |key| [prefix, key].join(".") }
+           end
+           missing_in_group = missing_keys.select { |k| group.include?(k) }
+           if missing_in_group.length > 0 && missing_in_group.length < group.length
+             missing_keys = missing_keys.select { |k| !missing_in_group.include?(k) }
+           end
+         end
+
+         # Make sure we are able to load base models as well as derived models (with heads)
+         start_prefix = ""
+         model_to_load = model
+         if base_model_prefix.length > 0 && !model.instance_variable_defined?("@#{base_model_prefix}") && has_prefix_module
+           start_prefix = base_model_prefix + "."
+         end
+         if base_model_prefix.length > 0 && model.instance_variable_defined?("@#{base_model_prefix}") && !has_prefix_module
+           model_to_load = model.instance_variable_get("@#{base_model_prefix}")
+           base_model_expected_keys = model_to_load.state_dict.keys
+           if loaded_keys.any? { |key| expected_keys_not_prefixed.include?(key) && !base_model_expected_keys.include?(key) }
+             raise ArgumentError, "The state dictionary of the model you are trying to load is corrupted. Are you sure it was properly saved?"
+           end
+           if !device_map.nil?
+             raise Todo
+           end
+         end
+
+         _find_mismatched_keys = lambda do |state_dict, model_state_dict, loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes|
+           mismatched_keys = []
+           if ignore_mismatched_sizes
+             loaded_keys.each do |checkpoint_key|
+               # If the checkpoint is sharded, we may not have the key here.
+               if !state_dict.include?(checkpoint_key)
+                 next
+               end
+               model_key = checkpoint_key
+               if remove_prefix_from_model
+                 # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                 model_key = "#{prefix}.#{checkpoint_key}"
+               elsif add_prefix_to_model
+                 # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                 model_key = checkpoint_key.split(".")[1..].join(".")
+               end
+
+               if model_state_dict.include?(model_key) && state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                 raise Todo
+               end
+             end
+           end
+           mismatched_keys
+         end
+
+         if !resolved_archive_file.nil?
+           _folder = File.dirname(resolved_archive_file)
+         else
+           _folder = nil
+         end
+
+         if !device_map.nil? && is_safetensors
+           raise Todo
+         end
+
+         if !state_dict.nil?
+           # Whole checkpoint
+           mismatched_keys =
+             _find_mismatched_keys.(
+               state_dict,
+               model_state_dict,
+               original_loaded_keys,
+               add_prefix_to_model,
+               remove_prefix_from_model,
+               ignore_mismatched_sizes
+             )
+           error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+           offload_index = nil
+         else
+           raise Todo
+         end
+
+         if error_msgs.length > 0
+           raise Todo
+         end
+
+         if unexpected_keys.length > 0
+           archs = model.config.architectures.nil? ? [] : model.config.architectures
+           warner = archs.include?(model.class.name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
+           warner.(
+             "Some weights of the model checkpoint at #{pretrained_model_name_or_path} were not used when" +
+             " initializing #{model.class.name}: #{unexpected_keys}\n- This IS expected if you are" +
+             " initializing #{model.class.name} from the checkpoint of a model trained on another task or" +
+             " with another architecture (e.g. initializing a BertForSequenceClassification model from a" +
+             " BertForPreTraining model).\n- This IS NOT expected if you are initializing" +
+             " #{model.class.name} from the checkpoint of a model that you expect to be exactly identical" +
+             " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+           )
+         else
+           Transformers.logger.info("All model checkpoint weights were used when initializing #{model.class.name}.\n")
+         end
+         if missing_keys.length > 0
+           Transformers.logger.info("Some weights of #{model.class.name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
+         elsif mismatched_keys.length == 0
+           Transformers.logger.info(
+             "All the weights of #{model.class.name} were initialized from the model checkpoint at" +
+             " #{pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" +
+             " was trained on, you can already use #{model.class.name} for predictions without further" +
+             " training."
+           )
+         end
+         if mismatched_keys.length > 0
+           raise Todo
+         end
+
+         [model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs]
+       end
+
+       def _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+         # Convert old format to new format if needed from a PyTorch state_dict
+         old_keys = []
+         new_keys = []
+         state_dict.each_key do |key|
+           new_key = nil
+           if key.include?("gamma")
+             new_key = key.gsub("gamma", "weight")
+           end
+           if key.include?("beta")
+             new_key = key.gsub("beta", "bias")
+           end
+           if new_key
+             old_keys << key
+             new_keys << new_key
+           end
+         end
+         old_keys.zip(new_keys) do |old_key, new_key|
+           state_dict[new_key] = state_dict.delete(old_key)
+         end
+
+         # copy state_dict so _load_from_state_dict can modify it
+         metadata = nil # getattr(state_dict, "_metadata", None)
+         state_dict = state_dict.dup
+         if !metadata.nil?
+           state_dict._metadata = metadata
+         end
+
+         error_msgs = []
+
+         # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+         # so we need to apply the function recursively.
+         load = lambda do |mod, state_dict, prefix|
+           local_metadata = metadata.nil? ? {} : metadata.fetch(prefix[...-1], {})
+           args = [state_dict, prefix, local_metadata, true, [], [], error_msgs]
+           # Parameters of module and children will start with prefix. We can exit early if there are none in this
+           # state_dict
+           if state_dict.any? { |key, _| key.start_with?(prefix) }
+             mod.send(:load_from_state_dict, *args)
+           end
+
+           mod.named_children.each do |name, child|
+             if !child.nil?
+               load.(child, state_dict, prefix + name + ".")
+             end
+           end
+         end
+
+         load.(model_to_load, state_dict, start_prefix)
+
+         error_msgs
+       end
+
+       def is_safetensors_available
+         defined?(Safetensors)
+       end
+
+       def load_state_dict(checkpoint_file)
+         if checkpoint_file.end_with?(".safetensors") && is_safetensors_available
+           # Check format of the archive
+           metadata = nil
+           Safetensors.safe_open(checkpoint_file, framework: "pt") do |f|
+             metadata = f.metadata
+           end
+           if !["pt", "tf", "flax"].include?(metadata["format"])
+             raise OSError, "The safetensors archive passed at #{checkpoint_file} does not contain the valid metadata. Make sure you save your model with the `save_pretrained` method."
+           end
+           return Safetensors::Torch.load_file(checkpoint_file)
+         end
+         begin
+           _map_location = "cpu"
+           _extra_args = {}
+           _weights_only_kwarg = {weights_only: true}
+           Torch.load(
+             checkpoint_file,
+             # Torch.rb does not currently support additional options
+             # map_location: map_location,
+             # **weights_only_kwarg,
+             # **extra_args
+           )
+         rescue => e
+           # TODO improve
+           raise e
+         end
+       end
+
+       def _add_variant(weights_name, variant)
+         if !variant.nil?
+           splits = weights_name.split(".")
+           splits = splits[...-1] + [variant] + splits[-1..]
+           weights_name = splits.join(".")
+         end
+
+         weights_name
+       end
+     end
+   end
+ end
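
As a worked illustration of the additive attention-mask arithmetic used by get_extended_attention_mask above, here is a minimal standalone sketch against the Torch.rb API this file already relies on (the tensor values are made up; only the arithmetic mirrors the code):

    require "torch"

    # A 0/1 padding mask for one sequence of length 5 with two padded positions.
    attention_mask = Torch.tensor([[1, 1, 1, 0, 0]])

    # Make it broadcastable to [batch_size, num_heads, seq_length, seq_length],
    # matching attention_mask[0.., nil, nil, 0..] in the code above.
    extended = attention_mask[0.., nil, nil, 0..]

    # Map 1 -> 0.0 (attend) and 0 -> a very large negative number (mask), so that
    # adding the result to the raw attention scores before the softmax
    # effectively removes the padded positions.
    extended = extended.to(dtype: :float32)
    additive = (1.0 - extended) * -3.40282e+38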