transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,888 @@
|
|
1
|
+
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
2
|
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
module Transformers
|
17
|
+
module ModuleUtilsMixin
  # Builds the additive attention mask that is added to raw attention scores.
  #
  # attention_mask - mask with 1 for positions to attend to and 0 for masked
  #                  positions. Only the 2-D [batch_size, seq_length] mask of a
  #                  non-decoder model is implemented; other cases raise Todo.
  # input_shape    - accepted for API parity; not used by this implementation.
  # device         - passing a device currently raises Todo.
  # dtype          - dtype to cast the mask to; defaults to @dtype.
  #
  # Returns a [batch_size, 1, 1, seq_length] tensor (broadcastable over heads and
  # query positions) that is 0 for attended positions and -3.40282e+38 for
  # masked ones.
  def get_extended_attention_mask(
    attention_mask,
    input_shape,
    device: nil,
    dtype: nil
  )
    if dtype.nil?
      dtype = @dtype
    end

    if !(attention_mask.dim == 2 && @config.is_decoder)
      # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
      if !device.nil?
        raise Todo
      end
    end
    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    if attention_mask.dim == 3
      raise Todo
    elsif attention_mask.dim == 2
      # Provided a padding mask of dimensions [batch_size, seq_length]
      # - if the model is a decoder, apply a causal mask in addition to the padding mask
      # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
      if @config.is_decoder
        raise Todo
      else
        extended_attention_mask = attention_mask[0.., nil, nil, 0..]
      end
    else
      raise Todo
    end

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and the dtype's smallest value for masked positions.
    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    extended_attention_mask = extended_attention_mask.to(dtype: dtype) # fp16 compatibility
    # TODO use Torch.finfo
    extended_attention_mask = (1.0 - extended_attention_mask) * -3.40282e+38
    extended_attention_mask
  end

  # Prepares the head mask consumed by the layers.
  #
  # head_mask          - nil, or a 1-D [num_heads] / 2-D [num_hidden_layers, num_heads] tensor.
  # num_hidden_layers  - number of layers in the model.
  # is_attention_chunked - when true, an extra trailing axis is appended.
  #
  # Returns an Array of num_hidden_layers nils when head_mask is nil, otherwise a
  # 5-D tensor [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  # (6-D when chunked).
  def get_head_mask(head_mask, num_hidden_layers, is_attention_chunked: false)
    if !head_mask.nil?
      head_mask = _convert_head_mask_to_5d(head_mask, num_hidden_layers)
      if is_attention_chunked == true
        head_mask = head_mask.unsqueeze(-1)
      end
    else
      head_mask = [nil] * num_hidden_layers
    end

    head_mask
  end

  # Expands a head mask to [num_hidden_layers x batch x num_heads x seq_length x seq_length].
  #
  # A 1-D [num_heads] mask is shared by every layer (expanded along the first
  # axis); a 2-D [num_hidden_layers, num_heads] mask already has one row per layer.
  #
  # Raises ArgumentError when the result is not 5-D (e.g. a 3-D input).
  def _convert_head_mask_to_5d(head_mask, num_hidden_layers)
    if head_mask.dim == 1
      head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
      head_mask = head_mask.expand([num_hidden_layers, -1, -1, -1, -1])
    elsif head_mask.dim == 2
      # We can specify head_mask for each layer
      head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
    end
    if head_mask.dim != 5
      raise ArgumentError, "head_mask.dim != 5, instead #{head_mask.dim}"
    end
    # switch to the model dtype when one is set (fp16 compatibility)
    head_mask = head_mask.to(dtype: @dtype) unless @dtype.nil?
    head_mask
  end
end
|
75
|
+
|
76
|
+
class PreTrainedModel < Torch::NN::Module
|
77
|
+
include ModuleUtilsMixin
extend ClassAttribute

# The configuration object the model was built with.
attr_reader :config

# Per-subclass settings declared through ClassAttribute; subclasses override
# these to describe their config type, base-model ivar name and input name.
class_attribute :config_class
class_attribute :base_model_prefix, ""
class_attribute :main_input_name, "input_ids"
class_attribute :model_tags

class_attribute :_tied_weights_keys
|
88
|
+
|
89
|
+
# Example inputs for tracing the model. Not ported yet: always raises Todo.
def dummy_inputs
  raise(Todo)
end
|
92
|
+
|
93
|
+
# Name of the tensor framework backing the model; always "pt" in this port.
def framework
  "pt"
end
|
96
|
+
|
97
|
+
# config         - model configuration, exposed through #config.
# inputs, kwargs - accepted for signature compatibility; ignored here.
def initialize(config, *inputs, **kwargs)
  # the parent initializer is called without forwarding any arguments
  super()
  @config = config
end
|
101
|
+
|
102
|
+
# Final construction step for subclasses: runs init_weights, then the
# gradient-checkpointing compatibility hook (whose result is returned).
def post_init
  init_weights
  _backward_compatibility_gradient_checkpointing
end
|
106
|
+
|
107
|
+
# Reverts quantization of the model weights. Not ported yet: always raises Todo.
def dequantize
  raise(Todo)
end
|
110
|
+
|
111
|
+
# Hook for legacy gradient-checkpointing settings. Not implemented yet (TODO);
# does nothing and returns nil.
def _backward_compatibility_gradient_checkpointing
  nil
end
|
114
|
+
|
115
|
+
# Returns the wrapped base model held in the instance variable named by the
# class's base_model_prefix (e.g. prefix "bert" -> @bert), or self when that
# variable is unset/nil.
#
# The class-level default prefix is "" (no wrapped base model). An empty prefix
# returns self directly, because instance_variable_get("@") raises NameError.
def base_model
  prefix = self.class.base_model_prefix
  return self if prefix.nil? || prefix.empty?

  instance_variable_get("@#{prefix}") || self
end
|
118
|
+
|
119
|
+
# Whether the model supports text generation. Hard-coded to false for now
# (TODO: detect generation-capable models).
def can_generate
  false
end
|
123
|
+
|
124
|
+
# Input embedding module accessor. Not ported yet: always raises Todo.
def get_input_embeddings
  raise(Todo)
end
|
127
|
+
|
128
|
+
# Replaces the input embedding module. Not ported yet: always raises Todo.
def set_input_embeddings(value)
  raise(Todo)
end
|
131
|
+
|
132
|
+
# Output embedding module, if any. The base class has none and returns nil;
# models with output embeddings override this.
def get_output_embeddings
  nil
end
|
135
|
+
|
136
|
+
# Per-module weight initializer; a no-op (returns nil) in the base class.
# Subclasses override it to initialize their own layers.
def _init_weights(mod)
  nil
end
|
139
|
+
|
140
|
+
# Callback passed to Module#apply by init_weights; delegates to the
# (overridable) _init_weights and returns its result.
def _initialize_weights(mod)
  _init_weights(mod)
end
|
143
|
+
|
144
|
+
# Ties shared weights.
#
# Tying output embeddings (when tie_word_embeddings is not false and the model
# has output embeddings) and tying encoder/decoder weights are not ported yet;
# both raise Todo. Every module from #modules that responds to _tie_weights is
# then asked to tie its own parameters. Returns the result of modules.each.
def tie_weights
  if @config.tie_word_embeddings != false && !get_output_embeddings.nil?
    raise Todo
  end

  raise Todo if @config.is_encoder_decoder && @config.tie_encoder_decoder

  modules.each { |submodule| submodule._tie_weights if submodule.respond_to?(:_tie_weights) }
end
|
162
|
+
|
163
|
+
# Initializes the model's weights.
#
# 1. Prunes the attention heads recorded in config.pruned_heads, but only when
#    there are any. An empty Hash is truthy in Ruby, so it is checked explicitly
#    rather than relying on truthiness.
# 2. Hands _initialize_weights to Module#apply.
# 3. Ties shared weights; returns the result of tie_weights.
def init_weights
  # Prune heads if needed
  pruned_heads = @config.pruned_heads
  prune_heads(pruned_heads) if pruned_heads && !pruned_heads.empty?

  # Initialize weights
  apply(method(:_initialize_weights))

  # Tie weights should be skipped when not initializing all weights
  # since from_pretrained(...) calls tie weights anyways
  tie_weights
end
|
178
|
+
|
179
|
+
# Prunes attention heads.
#
# heads_to_prune - {layer => heads to remove in that layer}.
#
# Each layer's entry in config.pruned_heads becomes the union of the heads
# already stored and the new ones, without duplicates and in first-seen order.
# Pruning itself is delegated to base_model._prune_heads, whose result is returned.
def prune_heads(heads_to_prune)
  registry = @config.pruned_heads
  heads_to_prune.each do |layer, heads|
    # Unfortunately we have to store it as list for JSON
    registry[layer] = registry.fetch(layer, []).to_a | heads.to_a
  end

  base_model._prune_heads(heads_to_prune)
end
|
188
|
+
|
189
|
+
# Detects padded input_ids passed without an attention_mask.
#
# Does nothing (returns nil) when a mask is given or the config has no
# pad_token_id. Otherwise it looks only at the last and first column of
# input_ids to keep the check cheap. The indexing assumes a 2-D
# (batch, seq) tensor; TODO confirm. If the pad token appears there it raises
# Todo, because the warning path is not ported yet.
def warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  return if !attention_mask.nil? || @config.pad_token_id.nil?

  pad_id = @config.pad_token_id
  edge_ids = input_ids[0.., [-1, 0]]
  raise Todo if edge_ids.include?(pad_id)
end
|
199
|
+
|
200
|
+
class << self
|
201
|
+
# Instantiates a model and loads pretrained weights into it.
#
# pretrained_model_name_or_path - Hub repo id, or a path to a local weights file.
#                                 Local directories are not supported yet (Todo).
# model_args                    - extra positional args forwarded to the model initializer.
# config                        - a PretrainedConfig, or a name/path to load one from.
# cache_dir, force_download, local_files_only, token, revision - passed to the Hub cache helpers.
# ignore_mismatched_sizes       - forwarded to _load_pretrained_model.
# use_safetensors               - nil/true prefers safetensors weights; false uses the PyTorch file.
# kwargs                        - extra options; the ones listed below are consumed here, and the
#                                 rest are forwarded to config_class.from_pretrained.
#
# Returns the model, in evaluation mode.
# Raises Todo for options not ported yet (device_map, quantization, torch_dtype,
# sharded checkpoints, TF/Flax weights, ...).
def from_pretrained(
  pretrained_model_name_or_path,
  *model_args,
  config: nil,
  cache_dir: nil,
  ignore_mismatched_sizes: false,
  force_download: false,
  local_files_only: false,
  token: nil,
  revision: "main",
  use_safetensors: nil,
  **kwargs
)
  state_dict = kwargs.delete(:state_dict)
  from_tf = kwargs.delete(:from_tf) { false }
  from_flax = kwargs.delete(:from_flax) { false }
  resume_download = kwargs.delete(:resume_download) { false }
  proxies = kwargs.delete(:proxies)
  output_loading_info = kwargs.delete(:output_loading_info) { false }
  _use_auth_token = kwargs.delete(:use_auth_token)
  trust_remote_code = kwargs.delete(:trust_remote_code)
  _ = kwargs.delete(:mirror)
  from_pipeline = kwargs.delete(:_from_pipeline)
  from_auto_class = kwargs.delete(:_from_auto) { false }
  _fast_init = kwargs.delete(:_fast_init) { true }
  torch_dtype = kwargs.delete(:torch_dtype)
  low_cpu_mem_usage = kwargs.delete(:low_cpu_mem_usage)
  device_map = kwargs.delete(:device_map)
  _max_memory = kwargs.delete(:max_memory)
  offload_folder = kwargs.delete(:offload_folder)
  offload_state_dict = kwargs.delete(:offload_state_dict) { false }
  load_in_8bit = kwargs.delete(:load_in_8bit) { false }
  load_in_4bit = kwargs.delete(:load_in_4bit) { false }
  quantization_config = kwargs.delete(:quantization_config)
  subfolder = kwargs.delete(:subfolder) { "" }
  commit_hash = kwargs.delete(:_commit_hash)
  variant = kwargs.delete(:variant)
  _adapter_kwargs = kwargs.delete(:adapter_kwargs) { {} }
  _adapter_name = kwargs.delete(:adapter_name) { "default" }
  _use_flash_attention_2 = kwargs.delete(:use_flash_attention_2) { false }

  if use_safetensors.nil? && !is_safetensors_available
    use_safetensors = false
  end
  if trust_remote_code
    Transformers.logger.warn(
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" +
      " ignored."
    )
  end

  if commit_hash.nil?
    if !config.is_a?(PretrainedConfig)
      # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
      resolved_config_file =
        Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          CONFIG_NAME,
          cache_dir: cache_dir,
          force_download: force_download,
          resume_download: resume_download,
          proxies: proxies,
          local_files_only: local_files_only,
          token: token,
          revision: revision,
          subfolder: subfolder,
          _raise_exceptions_for_gated_repo: false,
          _raise_exceptions_for_missing_entries: false,
          _raise_exceptions_for_connection_errors: false,
        )
      commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
    else
      commit_hash = config._commit_hash
    end
  end

  if !device_map.nil?
    raise Todo
  end

  # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
  if load_in_4bit || load_in_8bit
    raise Todo
  end

  # Only PyTorch-format weights (.bin / .safetensors) are loaded below. Fail
  # loudly for TF/Flax instead of returning a model whose weights were never loaded.
  from_pt = !(from_tf || from_flax)
  raise Todo unless from_pt

  user_agent = {file_type: "model", framework: "pytorch", from_auto_class: from_auto_class}
  if !from_pipeline.nil?
    user_agent[:using_pipeline] = from_pipeline
  end

  if Utils::Hub.is_offline_mode && !local_files_only
    Transformers.logger.info "Offline mode: forcing local_files_only: true"
    local_files_only = true
  end

  # Load config if we don't provide a configuration
  if !config.is_a?(PretrainedConfig)
    config_path = !config.nil? ? config : pretrained_model_name_or_path
    config, model_kwargs =
      config_class.from_pretrained(
        config_path,
        cache_dir: cache_dir,
        return_unused_kwargs: true,
        force_download: force_download,
        resume_download: resume_download,
        proxies: proxies,
        local_files_only: local_files_only,
        token: token,
        revision: revision,
        subfolder: subfolder,
        _from_auto: from_auto_class,
        _from_pipeline: from_pipeline,
        **kwargs
      )
  else
    # In case one passes a config to `from_pretrained` + "attn_implementation"
    # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
    # Please see: https://github.com/huggingface/transformers/issues/28038

    # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
    # we pop attn_implementation from the kwargs but this handles the case where users
    # passes manually the config to `from_pretrained`.
    config = Copy.deepcopy(config)

    kwarg_attn_imp = kwargs.delete(:attn_implementation)
    if !kwarg_attn_imp.nil? && config._attn_implementation != kwarg_attn_imp
      config._attn_implementation = kwarg_attn_imp
    end
    model_kwargs = kwargs
  end

  pre_quantized = false # !config.quantization_config.nil?
  if pre_quantized || !quantization_config.nil?
    raise Todo
  else
    hf_quantizer = nil
  end

  if !hf_quantizer.nil?
    raise Todo
  end

  # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
  # index of the files.
  is_sharded = false
  sharded_metadata = nil
  # Load model
  _loading_info = nil

  # Keep in fp32 modules
  keep_in_fp32_modules = nil
  _use_keep_in_fp32_modules = false

  resolved_archive_file = nil
  if !pretrained_model_name_or_path.nil?
    pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
    is_local = Dir.exist?(pretrained_model_name_or_path)
    # File.join("", name) returns "/name", so only join when a subfolder is given.
    local_weights_file =
      if subfolder.to_s.empty?
        pretrained_model_name_or_path
      else
        File.join(subfolder, pretrained_model_name_or_path)
      end

    if is_local
      raise Todo
    elsif File.file?(local_weights_file)
      archive_file = local_weights_file
      is_local = true
    else
      # set correct filename
      if use_safetensors != false
        filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
      else
        filename = _add_variant(WEIGHTS_NAME, variant)
      end

      # Load from URL or cache if already cached
      cached_file_kwargs = {
        cache_dir: cache_dir,
        force_download: force_download,
        proxies: proxies,
        resume_download: resume_download,
        local_files_only: local_files_only,
        token: token,
        user_agent: user_agent,
        revision: revision,
        subfolder: subfolder,
        _raise_exceptions_for_gated_repo: false,
        _raise_exceptions_for_missing_entries: false,
        _commit_hash: commit_hash
      }
      resolved_archive_file = Utils::Hub.cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

      # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
      # result when internet is up, the repo and revision exist, but the file does not.
      if resolved_archive_file.nil? && filename == _add_variant(SAFE_WEIGHTS_NAME, variant)
        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
        resolved_archive_file = Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
          **cached_file_kwargs,
        )
        if !resolved_archive_file.nil?
          is_sharded = true
        elsif use_safetensors
          raise Todo
        else
          # This repo has no safetensors file of any kind, we switch to PyTorch.
          filename = _add_variant(WEIGHTS_NAME, variant)
          resolved_archive_file = Utils::Hub.cached_file(
            pretrained_model_name_or_path, filename, **cached_file_kwargs
          )
        end
      end
      if resolved_archive_file.nil? && filename == _add_variant(WEIGHTS_NAME, variant)
        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
        resolved_archive_file = Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          _add_variant(WEIGHTS_INDEX_NAME, variant),
          **cached_file_kwargs
        )
        if !resolved_archive_file.nil?
          is_sharded = true
        end
      end
      # No weights file found while online. (Auto conversion to safetensors when
      # only a PyTorch file exists is intentionally skipped.)
      if resolved_archive_file.nil? && !local_files_only && !Utils::Hub.is_offline_mode
        raise Todo
      end
    end

    # Runs after the lookup chain so that a local weights file is actually used.
    if is_local
      Transformers.logger.info("loading weights file #{archive_file}")
      resolved_archive_file = archive_file
    else
      Transformers.logger.info("loading weights file #{filename} from cache at #{resolved_archive_file}")
    end
  end

  # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
  if is_sharded
    raise Todo
  end

  metadata = nil
  if is_safetensors_available && resolved_archive_file.is_a?(String) && resolved_archive_file.end_with?(".safetensors")
    Safetensors.safe_open(resolved_archive_file, framework: "pt") do |f|
      metadata = f.metadata
    end

    # A file without metadata is treated as a PyTorch checkpoint.
    if !metadata.nil? && metadata["format"] != "pt"
      raise ArgumentError,
        "Incompatible safetensors file. File metadata is not ['pt'] but #{metadata["format"]}"
    end
  end

  # load pt weights early so that we know which dtype to init the model under
  if !is_sharded && state_dict.nil?
    # Time to load the checkpoint
    state_dict = load_state_dict(resolved_archive_file)
  end

  # set dtype to instantiate the model under:
  # 1. If torch_dtype is not None, we use that dtype
  # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
  #    weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
  # we also may have config.torch_dtype available, but we won't rely on it till v5
  dtype_orig = nil

  if !torch_dtype.nil?
    raise Todo
  end

  if is_sharded
    loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
  else
    loaded_state_dict_keys = state_dict.keys
  end

  config.name_or_path = pretrained_model_name_or_path

  # NOTE(review): kwargs left over from config loading are deliberately not
  # forwarded to the model initializer; confirm this is intended.
  model_kwargs = {}
  model = new(config, *model_args, **model_kwargs)

  # make sure we use the model's config since the __init__ call might have copied it
  config = model.config

  if device_map.is_a?(String)
    raise Todo
  elsif !device_map.nil?
    raise Todo
  end

  # restore default dtype
  if !dtype_orig.nil?
    Torch.set_default_dtype(dtype_orig)
  end

  model, _missing_keys, _unexpected_keys, _mismatched_keys, _offload_index, _error_msgs =
    _load_pretrained_model(
      model,
      state_dict,
      loaded_state_dict_keys, # XXX: rename?
      resolved_archive_file,
      pretrained_model_name_or_path,
      ignore_mismatched_sizes: ignore_mismatched_sizes,
      sharded_metadata: sharded_metadata,
      _fast_init: _fast_init,
      low_cpu_mem_usage: low_cpu_mem_usage,
      device_map: device_map,
      offload_folder: offload_folder,
      offload_state_dict: offload_state_dict,
      dtype: torch_dtype,
      hf_quantizer: hf_quantizer,
      keep_in_fp32_modules: keep_in_fp32_modules
    )

  # make sure token embedding weights are still tied if needed
  model.tie_weights

  # Set model in evaluation mode to deactivate DropOut modules by default
  model.eval

  # If it is a model with generation capabilities, attempt to load the generation config
  if model.can_generate && !pretrained_model_name_or_path.nil?
    raise Todo
  end

  # Dispatch model with hooks on all devices if necessary
  if !device_map.nil?
    raise Todo
  end

  if !hf_quantizer.nil?
    raise Todo
  end

  if output_loading_info
    raise Todo
  end

  model
end
|
579
|
+
|
580
|
+
private
|
581
|
+
|
582
|
+
# Loads a (non-sharded) state_dict into a freshly constructed model and reports
# how the checkpoint keys matched the model.
#
# model                  - the newly instantiated model.
# state_dict             - Hash of parameter name => tensor (sharded loading raises Todo).
# loaded_keys            - the checkpoint's key names.
# resolved_archive_file  - path of the weights file, or nil.
# pretrained_model_name_or_path - used only in log messages.
# Remaining keywords mirror from_pretrained; device maps, offloading and
# mismatched sizes are not ported yet (Todo).
#
# Returns [model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs].
def _load_pretrained_model(
  model,
  state_dict,
  loaded_keys,
  resolved_archive_file,
  pretrained_model_name_or_path,
  ignore_mismatched_sizes: false,
  sharded_metadata: nil,
  _fast_init: true,
  low_cpu_mem_usage: false,
  device_map: nil,
  offload_folder: nil,
  offload_state_dict: nil,
  dtype: nil,
  hf_quantizer: nil,
  keep_in_fp32_modules: nil
)
  is_safetensors = false

  _is_sharded_safetensors = is_safetensors && !sharded_metadata.nil?

  # tie the model weights before retrieving the state_dict
  model.tie_weights

  # Retrieve missing & unexpected_keys
  model_state_dict = model.state_dict
  expected_keys = model_state_dict.keys
  prefix = model.class.base_model_prefix

  # Maps legacy parameter names to current ones. "beta" takes precedence over
  # "gamma" when a key contains both, matching _load_state_dict_into_model,
  # which applies the same renaming to the state_dict itself.
  _fix_key = lambda do |key|
    if key.include?("beta")
      key.gsub("beta", "bias")
    elsif key.include?("gamma")
      key.gsub("gamma", "weight")
    else
      key
    end
  end

  original_loaded_keys = loaded_keys
  loaded_keys = loaded_keys.map { |key| _fix_key.(key) }

  if prefix.length > 0
    has_prefix_module = loaded_keys.any? { |s| s.start_with?(prefix) }
    expects_prefix_module = expected_keys.any? { |s| s.start_with?(prefix) }
  else
    has_prefix_module = false
    expects_prefix_module = false
  end

  # key re-naming operations are never done on the keys
  # that are loaded, but always on the keys of the newly initialized model
  remove_prefix_from_model = !has_prefix_module && expects_prefix_module
  add_prefix_to_model = has_prefix_module && !expects_prefix_module

  # Model keys outside the base model (e.g. task heads). Only filled in when the
  # prefix is stripped, but always read by the corrupted-checkpoint check below.
  expected_keys_not_prefixed = []
  if remove_prefix_from_model
    _prefix = "#{prefix}."
    expected_keys_not_prefixed = expected_keys.select { |s| !s.start_with?(_prefix) }
    expected_keys = expected_keys.map { |s| s.start_with?(_prefix) ? s[_prefix.length..] : s }
  elsif add_prefix_to_model
    expected_keys = expected_keys.map { |s| [prefix, s].join(".") }
  end

  missing_keys = (Set.new(expected_keys) - Set.new(loaded_keys)).sort
  unexpected_keys = Set.new(loaded_keys) - Set.new(expected_keys)
  # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
  # buffers
  model_buffers = model.named_buffers(recurse: true).keys
  if remove_prefix_from_model
    model_buffers = model_buffers.map { |key| key.start_with?(_prefix) ? key[_prefix.length..] : key }
  elsif add_prefix_to_model
    model_buffers = model_buffers.map { |key| [prefix, key].join(".") }
  end
  unexpected_keys = (unexpected_keys - model_buffers).sort

  model.tie_weights
  if device_map.nil?
    ptrs = Hash.new { |hash, key| hash[key] = [] }

    model.state_dict.each do |name, tensor|
      # TODO fix
      id_tensor = tensor.object_id # id_tensor_storage(tensor)
      ptrs[id_tensor] << name
    end

    # These are all the pointers of shared tensors.
    tied_params = ptrs.select { |_, names| names.length > 1 }.values
  else
    raise Todo
  end

  tied_params.each do |group|
    if remove_prefix_from_model
      group = group.map { |key| key.delete_prefix(_prefix) }
    elsif add_prefix_to_model
      group = group.map { |key| [prefix, key].join(".") }
    end
    missing_in_group = missing_keys.select { |k| group.include?(k) }
    if missing_in_group.length > 0 && missing_in_group.length < group.length
      missing_keys = missing_keys.select { |k| !missing_in_group.include?(k) }
    end
  end

  # Make sure we are able to load base models as well as derived models (with heads)
  start_prefix = ""
  model_to_load = model
  if base_model_prefix.length > 0 && !model.instance_variable_defined?("@#{base_model_prefix}") && has_prefix_module
    start_prefix = base_model_prefix + "."
  end
  if base_model_prefix.length > 0 && model.instance_variable_defined?("@#{base_model_prefix}") && !has_prefix_module
    model_to_load = model.instance_variable_get("@#{base_model_prefix}")
    base_model_expected_keys = model_to_load.state_dict.keys
    if loaded_keys.any? { |key| expected_keys_not_prefixed.include?(key) && !base_model_expected_keys.include?(key) }
      raise ArgumentError, "The state dictionary of the model you are trying to load is corrupted. Are you sure it was properly saved?"
    end
    if !device_map.nil?
      raise Todo
    end
  end

  _find_mismatched_keys = lambda do |state_dict, model_state_dict, loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes|
    mismatched_keys = []
    if ignore_mismatched_sizes
      loaded_keys.each do |checkpoint_key|
        # If the checkpoint is sharded, we may not have the key here.
        if !state_dict.include?(checkpoint_key)
          next
        end
        model_key = checkpoint_key
        if remove_prefix_from_model
          # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
          model_key = "#{prefix}.#{checkpoint_key}"
        elsif add_prefix_to_model
          # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
          model_key = checkpoint_key.split(".")[1..].join(".")
        end

        if model_state_dict.include?(model_key) && state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
          raise Todo
        end
      end
    end
    mismatched_keys
  end

  if !resolved_archive_file.nil?
    _folder = File.dirname(resolved_archive_file)
  else
    _folder = nil
  end

  if !device_map.nil? && is_safetensors
    raise Todo
  end

  if !state_dict.nil?
    # Whole checkpoint
    mismatched_keys =
      _find_mismatched_keys.(
        state_dict,
        model_state_dict,
        original_loaded_keys,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes
      )
    error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
    offload_index = nil
  else
    raise Todo
  end

  if error_msgs.length > 0
    raise Todo
  end

  if unexpected_keys.length > 0
    archs = model.config.architectures.nil? ? [] : model.config.architectures
    # Class#name includes the enclosing modules (classes here live in module
    # Transformers), while config.architectures presumably holds bare class
    # names as in config.json. Accept either form.
    class_names = [model.class.name, model.class.name.to_s.split("::").last]
    warner = class_names.any? { |n| archs.include?(n) } ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
    warner.(
      "Some weights of the model checkpoint at #{pretrained_model_name_or_path} were not used when" +
      " initializing #{model.class.name}: #{unexpected_keys}\n- This IS expected if you are" +
      " initializing #{model.class.name} from the checkpoint of a model trained on another task or" +
      " with another architecture (e.g. initializing a BertForSequenceClassification model from a" +
      " BertForPreTraining model).\n- This IS NOT expected if you are initializing" +
      " #{model.class.name} from the checkpoint of a model that you expect to be exactly identical" +
      " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
    )
  else
    Transformers.logger.info("All model checkpoint weights were used when initializing #{model.class.name}.\n")
  end
  if missing_keys.length > 0
    Transformers.logger.info("Some weights of #{model.class.name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
  elsif mismatched_keys.length == 0
    Transformers.logger.info(
      "All the weights of #{model.class.name} were initialized from the model checkpoint at" +
      " #{pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" +
      " was trained on, you can already use #{model.class.name} for predictions without further" +
      " training."
    )
  end
  if mismatched_keys.length > 0
    raise Todo
  end

  [model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs]
end
|
790
|
+
|
791
|
+
# Copies the entries of +state_dict+ into +model_to_load+ and all of its descendants.
#
# Legacy parameter names are first normalized IN PLACE on the caller's hash
# ("gamma" -> "weight", "beta" -> "bias"). A shallow copy of the hash is then
# walked recursively. Every module for which at least one key starts with its
# dotted prefix receives a `load_from_state_dict` call; the call may be private,
# hence `send`.
#
# @param model_to_load [Object] root module; must respond to +named_children+
#   (yielding name/child pairs) and +load_from_state_dict+
# @param state_dict [Hash{String => Object}] checkpoint entries keyed by dotted name
# @param start_prefix [String] prefix prepended to the root module's names
# @return [Array] the error_msgs array handed to every +load_from_state_dict+
#   call, containing whatever those calls appended
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
  # Convert old format to new format if needed from a PyTorch state_dict.
  # Renames are collected first so the hash is not modified while iterating it.
  renames = state_dict.each_key.filter_map do |key|
    new_key = nil
    new_key = key.gsub("gamma", "weight") if key.include?("gamma")
    # checked second, so "beta" wins when a key contains both substrings
    new_key = key.gsub("beta", "bias") if key.include?("beta")
    [key, new_key] if new_key
  end
  renames.each { |old_key, new_key| state_dict[new_key] = state_dict.delete(old_key) }

  # Shallow copy handed to the per-module loaders so they can modify it.
  # The renames above have already been applied to the caller's hash.
  state_dict = state_dict.dup
  error_msgs = []

  # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
  # so we need to apply the function recursively. Named `load_module` so it does not
  # shadow Kernel#load.
  load_module = lambda do |mod, prefix|
    # Parameters of a module and its children start with prefix; skip the call
    # when the checkpoint has none of them.
    if state_dict.any? { |key, _| key.start_with?(prefix) }
      # args: state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
      mod.send(:load_from_state_dict, state_dict, prefix, {}, true, [], [], error_msgs)
    end

    mod.named_children.each do |name, child|
      load_module.(child, prefix + name + ".") unless child.nil?
    end
  end

  load_module.(model_to_load, start_prefix)

  error_msgs
end
|
843
|
+
|
844
|
+
# Reports whether the Safetensors constant is defined in this process.
#
# Returns a real boolean rather than the raw `defined?` result ("constant" or nil).
#
# @return [Boolean] true when Safetensors is defined, false otherwise
def is_safetensors_available
  return false unless defined?(Safetensors)

  true
end
|
847
|
+
|
848
|
+
# Loads a checkpoint file.
#
# Paths ending in ".safetensors" are read with Safetensors::Torch.load_file when
# the Safetensors constant is available. Any other path, or a safetensors path
# when the gem is not loaded, goes through Torch.load.
#
# @param checkpoint_file [String] path to the checkpoint
# @return [Object] the result of Safetensors::Torch.load_file or Torch.load
#   (presumably a Hash of tensors keyed by name; TODO confirm against callers)
# @raise [OSError] when a safetensors archive declares a "format" other than
#   pt/tf/flax/mlx
def load_state_dict(checkpoint_file)
  if checkpoint_file.end_with?(".safetensors") && is_safetensors_available
    # Check format of the archive
    metadata = nil
    Safetensors.safe_open(checkpoint_file, framework: "pt") do |f|
      metadata = f.metadata
    end
    # Archives written without a metadata header yield nil metadata. Only reject
    # archives whose header is present but declares an unsupported format.
    if !metadata.nil? && !["pt", "tf", "flax", "mlx"].include?(metadata["format"])
      raise OSError, "The safetensors archive passed at #{checkpoint_file} does not contain the valid metadata. Make sure you save your model with the `save_pretrained` method."
    end
    return Safetensors::Torch.load_file(checkpoint_file)
  end

  # Torch.rb does not currently support map_location / weights_only / extra
  # options, so the checkpoint is loaded with its defaults.
  Torch.load(checkpoint_file)
end
|
876
|
+
|
877
|
+
# Inserts +variant+ as an extra dot-separated component just before the last
# component of +weights_name+. For example, ("model.safetensors", "fp16")
# becomes "model.fp16.safetensors".
#
# Follows Python's `split(".")` / insert-before-last semantics on every input:
# - a name with no dot becomes "<variant>.<name>";
# - a trailing dot is preserved ("model." -> "model.<variant>.");
# - the empty string becomes "<variant>.".
#
# @param weights_name [String] base file name
# @param variant [String, nil] variant tag; nil leaves the name unchanged
# @return [String]
def _add_variant(weights_name, variant)
  return weights_name if variant.nil?

  head, separator, last = weights_name.rpartition(".")
  # rpartition returns ["", "", weights_name] when there is no dot at all
  parts = separator.empty? ? [variant, last] : [head, variant, last]
  parts.join(".")
end
|
886
|
+
end
|
887
|
+
end
|
888
|
+
end
|