transformers-rb 0.1.0
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in the public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
+++ data/lib/transformers/hf_hub/file_download.rb
@@ -0,0 +1,764 @@
module Transformers
  module HfHub
    # Return value when trying to load a file from cache but the file does not exist in the distant repo.
    CACHED_NO_EXIST = Object.new

    # Regex to get filename from a "Content-Disposition" header for CDN-served files
    HEADER_FILENAME_PATTERN = /filename="(?<filename>.*?)";/

    # Regex to check if the revision IS directly a commit_hash
    REGEX_COMMIT_HASH = /^[0-9a-f]{40}$/

    class HfFileMetadata
      attr_reader :commit_hash, :etag, :location, :size

      def initialize(commit_hash:, etag:, location:, size:)
        @commit_hash = commit_hash
        @etag = etag
        @location = location
        @size = size
      end
    end

    class << self
      def hf_hub_url(
        repo_id,
        filename,
        subfolder: nil,
        repo_type: nil,
        revision: nil,
        endpoint: nil
      )
        if subfolder == ""
          subfolder = nil
        end
        if !subfolder.nil?
          filename = "#{subfolder}/#{filename}"
        end

        if !REPO_TYPES.include?(repo_type)
          raise ArgumentError, "Invalid repo type"
        end

        if REPO_TYPES_URL_PREFIXES.include?(repo_type)
          repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
        end

        if revision.nil?
          revision = DEFAULT_REVISION
        end
        url =
          HUGGINGFACE_CO_URL_TEMPLATE %
            {repo_id: repo_id, revision: CGI.escape(revision), filename: CGI.escape(filename)}
        # Update endpoint if provided
        if !endpoint.nil? && url.start_with?(ENDPOINT)
          url = endpoint + url[ENDPOINT.length..]
        end
        url
      end

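      # A usage sketch (assumes the default huggingface.co endpoint and the
      # URL template from constants.rb; values illustrative):
      #
      #   Transformers::HfHub.hf_hub_url("bert-base-uncased", "config.json")
      #   # => "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
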
      def _request_wrapper(method, url, follow_relative_redirects: false, redirects: 0, **params)
        # Recursively follow relative redirects
        if follow_relative_redirects
          if redirects > 10
            raise "Too many redirects"
          end

          response = _request_wrapper(
            method,
            url,
            follow_relative_redirects: false,
            **params
          )

          # If redirection, we redirect only relative paths.
          # This is useful in case of a renamed repository.
          if response.is_a?(Net::HTTPRedirection)
            parsed_target = URI.parse(response["Location"])
            if netloc(parsed_target) == ""
              # This means it is a relative 'Location' header, as allowed by RFC 7231.
              # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
              # We want to follow this relative redirect!
              #
              # Highly inspired by `resolve_redirects` from the requests library.
              # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
              # `url` may already be a URI after a previous relative redirect,
              # so parse its string form.
              next_url = URI.parse(url.to_s)
              next_url.path = parsed_target.path
              return _request_wrapper(method, next_url, follow_relative_redirects: true, redirects: redirects + 1, **params)
            end
          end
          return response
        end

        # Perform request and return if status_code is not in the retry list.
        uri = URI.parse(url.to_s)

        http_options = {use_ssl: true}
        if params[:timeout]
          http_options[:open_timeout] = params[:timeout]
          http_options[:read_timeout] = params[:timeout]
          http_options[:write_timeout] = params[:timeout]
        end
        response =
          Net::HTTP.start(uri.host, uri.port, **http_options) do |http|
            http.send_request(method, uri.path, nil, params[:headers])
          end
        response.uri ||= uri
        hf_raise_for_status(response)
        response
      end

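      # Sketch of the relative-redirect case handled above (hosts illustrative):
      # a HEAD request to https://huggingface.co/old-name/resolve/main/config.json
      # answering "302, Location: /new-name/resolve/main/config.json" has an
      # empty netloc, so the wrapper retries the same host with the new path:
      # https://huggingface.co/new-name/resolve/main/config.json
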
      def http_get(
        url,
        temp_file,
        proxies: nil,
        resume_size: 0,
        headers: nil,
        expected_size: nil,
        displayed_filename: nil,
        _nb_retries: 5
      )
        uri = URI.parse(url)
        # guard against a nil headers hash before setting the range header
        headers ||= {}

        if resume_size > 0
          headers["range"] = "bytes=%d-" % [resume_size]
        end

        size = resume_size
        Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
          request = Net::HTTP::Get.new(uri)
          headers.each do |k, v|
            request[k] = v
          end
          http.request(request) do |response|
            case response
            when Net::HTTPSuccess
              if displayed_filename.nil?
                displayed_filename = url
                content_disposition = response["content-disposition"]
                if !content_disposition.nil?
                  match = HEADER_FILENAME_PATTERN.match(content_disposition)
                  if !match.nil?
                    # Means file is on CDN
                    displayed_filename = match["filename"]
                  end
                end
              end

              stream = STDERR
              tty = stream.tty?
              width = tty ? stream.winsize[1] : 80

              response.read_body do |chunk|
                temp_file.write(chunk)
                size += chunk.bytesize

                if tty
                  stream.print "\r#{display_progress(displayed_filename, width, size, expected_size)}"
                end
              end

              if tty
                stream.puts
              else
                stream.puts display_progress(displayed_filename, width, size, expected_size)
              end
            else
              hf_raise_for_status(response)
            end
          end
        end
      end

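      # Resumption above is plain HTTP range semantics. For example, with 1024
      # bytes already in temp_file (value illustrative), the request carries
      # "range: bytes=1024-" and only the remaining bytes are fetched and
      # appended, so `size` continues from `resume_size`.
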
      def _normalize_etag(etag)
        if etag.nil?
          return nil
        end
        etag.sub(/\A\W/, "").delete('"')
      end

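      # Example: a quoted etag becomes a plain hex string usable as a blob
      # filename on disk (value illustrative):
      #
      #   _normalize_etag("\"d85bdff3a3f0b6b8bf8d3bbbfbb7b2d5\"")
      #   # => "d85bdff3a3f0b6b8bf8d3bbbfbb7b2d5"
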
      def _create_symlink(src, dst, new_blob: false)
        begin
          FileUtils.rm(dst)
        rescue Errno::ENOENT
          # do nothing
        end

        # abs_src = File.absolute_path(File.expand_path(src))
        # abs_dst = File.absolute_path(File.expand_path(dst))
        # abs_dst_folder = File.dirname(abs_dst)

        FileUtils.symlink(src, dst)
      end

      def _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
        if revision != commit_hash
          ref_path = Pathname.new(storage_folder) / "refs" / revision
          ref_path.parent.mkpath
          if !ref_path.exist? || commit_hash != ref_path.read
            # Update the ref only if it has changed. Writing unconditionally could cause
            # a useless error when the repo is already cached and the user doesn't have
            # write access to the cache folder.
            # See https://github.com/huggingface/huggingface_hub/issues/1216.
            ref_path.write(commit_hash)
          end
        end
      end

      def repo_folder_name(repo_id:, repo_type:)
        # remove all `/` occurrences to correctly convert repo to directory name
        parts = ["#{repo_type}s"] + repo_id.split("/")
        parts.join(REPO_ID_SEPARATOR)
      end

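      # Example, assuming REPO_ID_SEPARATOR is "--" (consistent with the
      # "#{repo_type}s--..." layout used by try_to_load_from_cache below):
      #
      #   repo_folder_name(repo_id: "google-bert/bert-base-uncased", repo_type: "model")
      #   # => "models--google-bert--bert-base-uncased"
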
      def _check_disk_space(expected_size, target_dir)
        # TODO
      end

      def hf_hub_download(
        repo_id,
        filename,
        subfolder: nil,
        repo_type: nil,
        revision: nil,
        library_name: nil,
        library_version: nil,
        cache_dir: nil,
        local_dir: nil,
        local_dir_use_symlinks: "auto",
        user_agent: nil,
        force_download: false,
        force_filename: nil,
        proxies: nil,
        etag_timeout: DEFAULT_ETAG_TIMEOUT,
        resume_download: false,
        token: nil,
        local_files_only: false,
        legacy_cache_layout: false,
        endpoint: nil
      )
        if cache_dir.nil?
          cache_dir = HF_HUB_CACHE
        end
        if revision.nil?
          revision = DEFAULT_REVISION
        end

        if subfolder == ""
          subfolder = nil
        end
        if !subfolder.nil?
          # This is used to create a URL, and not a local path, hence the forward slash.
          filename = "#{subfolder}/#{filename}"
        end

        if repo_type.nil?
          repo_type = "model"
        end
        if !REPO_TYPES.include?(repo_type)
          raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
        end

        headers =
          build_hf_headers(
            token: token,
            library_name: library_name,
            library_version: library_version,
            user_agent: user_agent
          )

        if !local_dir.nil?
          raise Todo
        else
          _hf_hub_download_to_cache_dir(
            # Destination
            cache_dir: cache_dir,
            # File info
            repo_id: repo_id,
            filename: filename,
            repo_type: repo_type,
            revision: revision,
            # HTTP info
            endpoint: endpoint,
            etag_timeout: etag_timeout,
            headers: headers,
            proxies: proxies,
            token: token,
            # Additional options
            local_files_only: local_files_only,
            force_download: force_download
          )
        end
      end

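      # The cache layout produced by a successful download (paths illustrative):
      #
      #   <cache_dir>/models--bert-base-uncased/
      #     blobs/<etag>                         # file contents, keyed by etag
      #     refs/main                            # resolved commit hash for the ref
      #     snapshots/<commit_hash>/config.json  # symlink to the blob
      #
      # A minimal usage sketch (model and file names illustrative):
      #
      #   path = Transformers::HfHub.hf_hub_download("bert-base-uncased", "config.json")
      #   # => "<cache_dir>/models--bert-base-uncased/snapshots/<commit_hash>/config.json"
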
      def _hf_hub_download_to_cache_dir(
        cache_dir:,
        # File info
        repo_id:,
        filename:,
        repo_type:,
        revision:,
        # HTTP info
        endpoint:,
        etag_timeout:,
        headers:,
        proxies:,
        token:,
        # Additional options
        local_files_only:,
        force_download:
      )
        _locks_dir = File.join(cache_dir, ".locks")
        storage_folder = File.join(cache_dir, repo_folder_name(repo_id: repo_id, repo_type: repo_type))

        # cross platform transcription of filename, to be used as a local file path.
        relative_filename = File.join(*filename.split("/"))

        # if user provides a commit_hash and they already have the file on disk, shortcut everything.
        if REGEX_COMMIT_HASH.match?(revision)
          pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
          if File.exist?(pointer_path) && !force_download
            return pointer_path
          end
        end

        # Try to get metadata (etag, commit_hash, url, size) from the server.
        # If we can't, a HEAD request error is returned.
        url_to_download, etag, commit_hash, expected_size, head_call_error = _get_metadata_or_catch_error(
          repo_id: repo_id,
          filename: filename,
          repo_type: repo_type,
          revision: revision,
          endpoint: endpoint,
          proxies: proxies,
          etag_timeout: etag_timeout,
          headers: headers,
          token: token,
          local_files_only: local_files_only,
          storage_folder: storage_folder,
          relative_filename: relative_filename
        )

        # etag can be nil for several reasons:
        # 1. we passed local_files_only.
        # 2. we don't have a connection.
        # 3. the Hub is down (HTTP 500 or 504).
        # 4. the repo is not found (for example, it is private or gated) and an invalid or missing token was sent.
        # 5. the Hub is blocked by a firewall or the proxy is not set correctly.
        # => Try to get the last downloaded one from the specified revision.
        #
        # If the specified revision is a commit hash, look inside "snapshots".
        # If the specified revision is a branch or tag, look inside "refs".
        if !head_call_error.nil?
          # Couldn't make a HEAD call => let's try to find a local file
          if !force_download
            commit_hash = nil
            if REGEX_COMMIT_HASH.match(revision)
              commit_hash = revision
            else
              ref_path = File.join(storage_folder, "refs", revision)
              if File.exist?(ref_path)
                commit_hash = File.read(ref_path)
              end
            end

            # Return pointer file if exists
            if !commit_hash.nil?
              pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
              if File.exist?(pointer_path) && !force_download
                return pointer_path
              end
            end
          end

          # Otherwise, raise appropriate error
          _raise_on_head_call_error(head_call_error, force_download, local_files_only)
        end

        # From now on, etag and commit_hash are not nil.
        raise "etag must have been retrieved from server" if etag.nil?
        raise "commit_hash must have been retrieved from server" if commit_hash.nil?
        raise "file location must have been retrieved from server" if url_to_download.nil?
        raise "expected_size must have been retrieved from server" if expected_size.nil?
        blob_path = File.join(storage_folder, "blobs", etag)
        pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)

        FileUtils.mkdir_p(File.dirname(blob_path))
        FileUtils.mkdir_p(File.dirname(pointer_path))

        # if passed revision is not identical to commit_hash
        # then revision has to be a branch name or tag name.
        # In that case store a ref.
        _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

        if !force_download
          if File.exist?(pointer_path)
            return pointer_path
          end

          if File.exist?(blob_path)
            # we have the blob already, but not the pointer
            _create_symlink(blob_path, pointer_path, new_blob: false)
            return pointer_path
          end
        end

        # Prevent parallel downloads of the same file with a lock.
        # etag could be duplicated across repos,
        # lock_path = File.join(locks_dir, repo_folder_name(repo_id: repo_id, repo_type: repo_type), "#{etag}.lock")

        _download_to_tmp_and_move(
          incomplete_path: Pathname.new(blob_path + ".incomplete"),
          destination_path: Pathname.new(blob_path),
          url_to_download: url_to_download,
          proxies: proxies,
          headers: headers,
          expected_size: expected_size,
          filename: filename,
          force_download: force_download
        )
        _create_symlink(blob_path, pointer_path, new_blob: true)

        pointer_path
      end

      def try_to_load_from_cache(
        repo_id,
        filename,
        cache_dir: nil,
        revision: nil,
        repo_type: nil
      )
        if revision.nil?
          revision = "main"
        end
        if repo_type.nil?
          repo_type = "model"
        end
        if !REPO_TYPES.include?(repo_type)
          raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
        end
        if cache_dir.nil?
          cache_dir = HF_HUB_CACHE
        end

        object_id = repo_id.gsub("/", "--")
        repo_cache = File.join(cache_dir, "#{repo_type}s--#{object_id}")
        if !Dir.exist?(repo_cache)
          # No cache for this model
          return nil
        end

        refs_dir = File.join(repo_cache, "refs")
        snapshots_dir = File.join(repo_cache, "snapshots")
        no_exist_dir = File.join(repo_cache, ".no_exist")

        # Resolve refs (for instance to convert main to the associated commit sha)
        if Dir.exist?(refs_dir)
          revision_file = File.join(refs_dir, revision)
          if File.exist?(revision_file)
            revision = File.read(revision_file)
          end
        end

        # Check if file is cached as "no_exist"
        if File.exist?(File.join(no_exist_dir, revision, filename))
          return CACHED_NO_EXIST
        end

        # Check if revision folder exists
        if !Dir.exist?(snapshots_dir)
          return nil
        end
        cached_shas = Dir.glob("*", base: snapshots_dir)
        if !cached_shas.include?(revision)
          # No cache for this revision and we won't try to return a random revision
          return nil
        end

        # Check if file exists in cache
        cached_file = File.join(snapshots_dir, revision, filename)
        File.exist?(cached_file) ? cached_file : nil
      end

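      # A usage sketch: the three possible results are a snapshot path, the
      # CACHED_NO_EXIST sentinel, and nil (repo and file names illustrative):
      #
      #   case (f = try_to_load_from_cache("bert-base-uncased", "config.json"))
      #   when CACHED_NO_EXIST then :known_missing_upstream
      #   when nil then :not_cached # a download is needed
      #   else File.read(f)         # f is the cached snapshot path
      #   end
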
      def get_hf_file_metadata(
        url,
        token: nil,
        proxies: nil,
        timeout: DEFAULT_REQUEST_TIMEOUT,
        library_name: nil,
        library_version: nil,
        user_agent: nil,
        headers: nil
      )
        headers =
          build_hf_headers(
            token: token,
            library_name: library_name,
            library_version: library_version,
            user_agent: user_agent,
            headers: headers
          )
        headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file

        # Retrieve metadata
        r =
          _request_wrapper(
            "HEAD",
            url,
            headers: headers,
            allow_redirects: false,
            follow_relative_redirects: true,
            proxies: proxies,
            timeout: timeout
          )
        hf_raise_for_status(r)

        # Return
        HfFileMetadata.new(
          commit_hash: r[HUGGINGFACE_HEADER_X_REPO_COMMIT],
          # We favor a custom header indicating the etag of the linked resource, and
          # we fall back to the regular etag header.
          etag: _normalize_etag(r[HUGGINGFACE_HEADER_X_LINKED_ETAG] || r["etag"]),
          # Either from response headers (if redirected) or defaults to the request url.
          # Do not use `url` directly, as `_request_wrapper` might have followed relative
          # redirects.
          location: r["location"] || r.uri.to_s,
          size: _int_or_none(r[HUGGINGFACE_HEADER_X_LINKED_SIZE] || r["content-length"])
        )
      end

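      # Shape of the returned metadata (values illustrative):
      #
      #   meta = get_hf_file_metadata(url)
      #   meta.commit_hash # 40-char commit sha from the x-repo-commit header
      #   meta.etag        # normalized etag, later used as the blob filename
      #   meta.location    # final URL (possibly a CDN) after redirects
      #   meta.size        # size in bytes, from the linked-size/content-length header
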
      def _get_metadata_or_catch_error(
        repo_id:,
        filename:,
        repo_type:,
        revision:,
        endpoint:,
        proxies:,
        etag_timeout:,
        headers:, # mutated in place!
        token:,
        local_files_only:,
        relative_filename: nil, # only used to store `.no_exist` in cache
        storage_folder: nil # only used to store `.no_exist` in cache
      )
        if local_files_only
          return [
            nil,
            nil,
            nil,
            nil,
            OfflineModeIsEnabled.new(
              "Cannot access file since 'local_files_only: true' has been set. (repo_id: #{repo_id}, repo_type: #{repo_type}, revision: #{revision}, filename: #{filename})"
            )
          ]
        end

        url = hf_hub_url(repo_id, filename, repo_type: repo_type, revision: revision, endpoint: endpoint)
        url_to_download = url
        etag = nil
        commit_hash = nil
        expected_size = nil
        head_call_error = nil

        if !local_files_only
          metadata = nil
          begin
            metadata =
              get_hf_file_metadata(
                url,
                proxies: proxies,
                timeout: etag_timeout,
                headers: headers,
                token: token
              )
          rescue => e
            # TODO: catch specific errors (offline mode, revision not found, ...)
            # and store them in head_call_error instead of re-raising.
            raise e
          end

          # Commit hash must exist
          commit_hash = metadata.commit_hash
          if commit_hash.nil?
            raise Todo
          end

          # Etag must exist
          etag = metadata.etag
          if etag.nil?
            raise Todo
          end

          # Expected (uncompressed) size
          expected_size = metadata.size
          if expected_size.nil?
            raise Todo
          end

          if metadata.location != url
            url_to_download = metadata.location
            if netloc(URI.parse(url)) != netloc(URI.parse(metadata.location))
              # Remove authorization header when downloading a LFS blob
              headers.delete("authorization")
            end
          end
        end

        if !(local_files_only || !etag.nil? || !head_call_error.nil?)
          raise "etag is empty due to uncovered problems"
        end

        [url_to_download, etag, commit_hash, expected_size, head_call_error]
      end

      def _raise_on_head_call_error(head_call_error, force_download, local_files_only)
        # No head call => we cannot force download.
        if force_download
          if local_files_only
            raise ArgumentError, "Cannot pass 'force_download: true' and 'local_files_only: true' at the same time."
          elsif head_call_error.is_a?(OfflineModeIsEnabled)
            raise ArgumentError, "Cannot pass 'force_download: true' when offline mode is enabled."
          else
            raise ArgumentError, "Force download failed due to the above error."
          end
        end

        # If we couldn't find an appropriate file on disk, raise an error.
        # If files cannot be found and local_files_only is true,
        # the files might have been found with local_files_only set to false.
        # Notify the user about that.
        if local_files_only
          raise LocalEntryNotFoundError,
            "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable" +
            " hf.co look-ups and downloads online, set 'local_files_only' to false."
        elsif head_call_error.is_a?(RepositoryNotFoundError) || head_call_error.is_a?(GatedRepoError)
          # Repo not found or gated => let's raise the actual error
          raise head_call_error
        else
          # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
          raise LocalEntryNotFoundError,
            "An error happened while trying to locate the file on the Hub and we cannot find the requested files" +
            " in the local cache. Please check your connection and try again, or make sure your Internet connection" +
            " is on."
        end
      end

      def _download_to_tmp_and_move(
        incomplete_path:,
        destination_path:,
        url_to_download:,
        proxies:,
        headers:,
        expected_size:,
        filename:,
        force_download:
      )
        if destination_path.exist? && !force_download
          # Do nothing if already exists (except if force_download: true)
          return
        end

        if incomplete_path.exist? && (force_download || (HF_HUB_ENABLE_HF_TRANSFER && !proxies))
          # By default, we will try to resume the download if possible.
          # However, if the user has set `force_download: true` or if `hf_transfer` is enabled, then we should
          # not resume the download => delete the incomplete file.
          message = "Removing incomplete file '#{incomplete_path}'"
          if force_download
            message += " (force_download: true)"
          elsif HF_HUB_ENABLE_HF_TRANSFER && !proxies
            message += " (hf_transfer: true)"
          end
          Transformers.logger.info(message)
          incomplete_path.unlink
        end

        incomplete_path.open("ab") do |f|
          f.seek(0, IO::SEEK_END)
          resume_size = f.tell
          message = "Downloading '#{filename}' to '#{incomplete_path}'"
          if resume_size > 0 && !expected_size.nil?
            message += " (resume from #{resume_size}/#{expected_size})"
          end
          Transformers.logger.info(message)

          if !expected_size.nil? # might be nil if the HTTP header is not set correctly
            # Check disk space in both tmp and destination path
            _check_disk_space(expected_size, incomplete_path.parent)
            _check_disk_space(expected_size, destination_path.parent)
          end

          http_get(
            url_to_download,
            f,
            proxies: proxies,
            resume_size: resume_size,
            headers: headers,
            expected_size: expected_size
          )
        end

        Transformers.logger.info("Download complete. Moving file to #{destination_path}")
        _chmod_and_move(incomplete_path, destination_path)
      end

      def _int_or_none(value)
        value&.to_i
      end

      def _chmod_and_move(src, dst)
        tmp_file = dst.parent.parent / "tmp_#{SecureRandom.uuid}"
        begin
          FileUtils.touch(tmp_file)
          cache_dir_mode = Pathname.new(tmp_file).stat.mode
          src.chmod(cache_dir_mode)
        ensure
          begin
            tmp_file.unlink
          rescue Errno::ENOENT
            # fails if `FileUtils.touch(tmp_file)` failed => do nothing
            # See https://github.com/huggingface/huggingface_hub/issues/2359
          end
        end

        FileUtils.move(src.to_s, dst.to_s)
      end

      def _get_pointer_path(storage_folder, revision, relative_filename)
        snapshot_path = File.join(storage_folder, "snapshots")
        pointer_path = File.join(snapshot_path, revision, relative_filename)
        if !parents(Pathname.new(File.absolute_path(pointer_path))).include?(Pathname.new(File.absolute_path(snapshot_path)))
          raise ArgumentError,
            "Invalid pointer path: cannot create pointer path in snapshot folder if" +
            " `storage_folder: #{storage_folder.inspect}`, `revision: #{revision.inspect}` and" +
            " `relative_filename: #{relative_filename.inspect}`."
        end
        pointer_path
      end

      # additional methods

      def netloc(uri)
        [uri.host, uri.port].compact.join(":")
      end

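      # Mirrors Python's `urlsplit(...).netloc`, which the redirect logic in
      # `_request_wrapper` relies on:
      #
      #   netloc(URI.parse("https://huggingface.co/x")) # => "huggingface.co:443"
      #   netloc(URI.parse("/relative/path"))           # => "" (relative URL)
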
      def parents(path)
        parents = []
        100.times do
          if path == path.parent
            break
          end
          path = path.parent
          parents << path
        end
        parents
      end

      def display_progress(filename, width, size, expected_size)
        bar_width = width - (filename.length + 3)
        progress = size / expected_size.to_f
        done = (progress * bar_width).round
        not_done = bar_width - done
        "#{filename} |#{"█" * done}#{" " * not_done}|"
      end
    end
  end
end
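
Taken together, hf_hub_download is the public entry point of this file. A minimal usage sketch (model and file names illustrative; the gem is loaded via data/lib/transformers-rb.rb):

    require "transformers-rb"

    # Resolves from the local cache when possible, otherwise downloads
    # into <cache_dir>/models--.../blobs and returns the snapshot path.
    path = Transformers::HfHub.hf_hub_download(
      "bert-base-uncased",
      "config.json",
      revision: "main"
    )
    puts path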