transformers-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
@@ -0,0 +1,764 @@
1
+ module Transformers
2
+ module HfHub
3
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
# It is a bare Object, so callers must compare against it by identity.
CACHED_NO_EXIST = Object.new

# Regex to get filename from a "Content-Disposition" header for CDN-served files
HEADER_FILENAME_PATTERN = /filename="(?<filename>.*?)";/

# Regex to check if the revision IS directly a commit_hash (exactly 40 lowercase hex characters).
# Anchored with \A/\z on purpose: ^/$ match at line boundaries, so a multi-line revision that merely
# contains a hash line would otherwise be accepted and later used as a snapshot path component.
REGEX_COMMIT_HASH = /\A[0-9a-f]{40}\z/
11
+
12
# Plain read-only record of what a HEAD request on a Hub file URL reported.
#
# Fields (any of them may be nil when the corresponding response header was absent):
#   commit_hash - commit the file resolved to
#   etag        - etag of the file content
#   location    - URL the file should be downloaded from
#   size        - size in bytes
class HfFileMetadata
  attr_reader :commit_hash, :etag, :location, :size

  def initialize(commit_hash:, etag:, location:, size:)
    @commit_hash, @etag, @location, @size = commit_hash, etag, location, size
  end
end
22
+
23
+ class << self
24
# Build the download URL of `filename` in `repo_id` from HUGGINGFACE_CO_URL_TEMPLATE.
#
# `revision` is percent-encoded as a single path segment (a "/" in a branch name becomes %2F).
# `filename` is encoded path-style: its "/" separators are kept so that subfolder files map to
# nested URL segments. Spaces become %20 rather than "+", because "+" is a literal plus in a URL path.
#
# @param repo_id [String] e.g. "org/name"; prefixed with REPO_TYPES_URL_PREFIXES[repo_type] when present.
# @param filename [String] file path inside the repo.
# @param subfolder [String, nil] joined in front of `filename` with "/"; "" is treated as nil.
# @param repo_type [String, nil] must be included in REPO_TYPES.
# @param revision [String, nil] defaults to DEFAULT_REVISION.
# @param endpoint [String, nil] replaces the ENDPOINT prefix of the generated URL when given.
# @return [String]
# @raise [ArgumentError] when repo_type is not in REPO_TYPES.
def hf_hub_url(
  repo_id,
  filename,
  subfolder: nil,
  repo_type: nil,
  revision: nil,
  endpoint: nil
)
  subfolder = nil if subfolder == ""
  filename = "#{subfolder}/#{filename}" unless subfolder.nil?

  raise ArgumentError, "Invalid repo type" unless REPO_TYPES.include?(repo_type)

  if REPO_TYPES_URL_PREFIXES.include?(repo_type)
    repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
  end

  revision = DEFAULT_REVISION if revision.nil?

  # CGI.escape encodes a literal "+" as %2B, so any "+" left afterwards stands for a space.
  quoted_revision = CGI.escape(revision).gsub("+", "%20")
  quoted_filename = CGI.escape(filename).gsub("+", "%20").gsub("%2F", "/")

  url =
    HUGGINGFACE_CO_URL_TEMPLATE %
    {repo_id: repo_id, revision: quoted_revision, filename: quoted_filename}
  # Update endpoint if provided
  if !endpoint.nil? && url.start_with?(ENDPOINT)
    url = endpoint + url[ENDPOINT.length..]
  end
  url
end
59
+
60
# Perform a single HTTP request and return the Net::HTTPResponse after it has been passed to
# `hf_raise_for_status`.
#
# With `follow_relative_redirects: true`, a redirect whose Location has no host (a relative location,
# as allowed by RFC 7231, e.g. after a repository rename) is followed, up to 10 hops. Absolute
# redirects are returned to the caller unchanged.
#
# @param method [String] HTTP verb, e.g. "HEAD".
# @param url [String, URI::Generic] target URL.
# @param params [Hash] `:headers` (Hash) and `:timeout` (seconds, applied to open/read/write) are used.
#   Other keys (e.g. `:proxies`, `:allow_redirects`) are accepted and ignored.
# @raise [RuntimeError] when more than 10 relative redirects are chained.
def _request_wrapper(method, url, follow_relative_redirects: false, redirects: 0, **params)
  # Recursively follow relative redirects
  if follow_relative_redirects
    raise "Too many redirects" if redirects > 10

    response = _request_wrapper(method, url, follow_relative_redirects: false, **params)

    # If redirection, we redirect only relative paths.
    # This is useful in case of a renamed repository.
    # Some 3xx responses (e.g. 304) carry no Location header; those are returned as-is.
    location = response.is_a?(Net::HTTPRedirection) ? response["Location"] : nil
    unless location.nil?
      parsed_target = URI.parse(location)
      if netloc(parsed_target) == ""
        # Relative 'location' header (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource').
        # Highly inspired by `resolve_redirects` from requests library.
        # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
        next_url = URI.parse(url.to_s)
        next_url.path = parsed_target.path
        # Pass a String: URI.parse, used on the next hop, rejects URI objects.
        return _request_wrapper(method, next_url.to_s, follow_relative_redirects: true, redirects: redirects + 1, **params)
      end
    end
    return response
  end

  # Single request, no retries.
  uri = URI.parse(url.to_s)

  # Only enable TLS for https URLs so that custom http:// endpoints work.
  http_options = {use_ssl: uri.scheme == "https"}
  if params[:timeout]
    http_options[:open_timeout] = params[:timeout]
    http_options[:read_timeout] = params[:timeout]
    http_options[:write_timeout] = params[:timeout]
  end
  response =
    Net::HTTP.start(uri.host, uri.port, **http_options) do |http|
      # request_uri keeps the query string (uri.path would drop it) and is "/" for an empty path.
      http.send_request(method, uri.request_uri, nil, params[:headers])
    end
  response.uri ||= uri
  hf_raise_for_status(response)
  response
end
110
+
111
# Stream `url` into the already-open `temp_file`, optionally resuming after `resume_size` bytes.
#
# Progress is drawn on STDERR. When STDERR is a TTY the bar is redrawn in place; otherwise it is
# printed once at the end. Non-2xx responses are handed to `hf_raise_for_status`.
#
# @param url [String] URL to download.
# @param temp_file [IO] destination; chunks are written in the order they arrive.
# @param proxies [Object] accepted for API compatibility; not used.
# @param resume_size [Integer] bytes already in `temp_file`; when > 0 a `range` header is sent.
# @param headers [Hash, nil] extra request headers. The caller's hash is never modified.
# @param expected_size [Integer, nil] total size, used only for the progress bar.
# @param displayed_filename [String, nil] label for the progress bar. Defaults to the
#   Content-Disposition filename when present, otherwise the URL.
# @param _nb_retries [Integer] accepted for API compatibility; retries are not implemented.
def http_get(
  url,
  temp_file,
  proxies: nil,
  resume_size: 0,
  headers: nil,
  expected_size: nil,
  displayed_filename: nil,
  _nb_retries: 5
)
  uri = URI.parse(url)

  # Work on a copy: adding `range` must not leak into the caller's hash, and `headers: nil` must work.
  headers = (headers || {}).dup
  headers["range"] = "bytes=%d-" % [resume_size] if resume_size > 0

  size = resume_size
  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
    request = Net::HTTP::Get.new(uri)
    headers.each { |k, v| request[k] = v }
    http.request(request) do |response|
      case response
      when Net::HTTPSuccess
        if displayed_filename.nil?
          displayed_filename = url
          content_disposition = response["content-disposition"]
          unless content_disposition.nil?
            match = HEADER_FILENAME_PATTERN.match(content_disposition)
            # Means file is on CDN
            displayed_filename = match["filename"] unless match.nil?
          end
        end

        stream = STDERR
        tty = stream.tty?
        # IO#winsize is provided by io/console; fall back to 80 columns when it is unavailable.
        width = tty && stream.respond_to?(:winsize) ? stream.winsize[1] : 80

        response.read_body do |chunk|
          temp_file.write(chunk)
          size += chunk.bytesize
          stream.print "\r#{display_progress(displayed_filename, width, size, expected_size)}" if tty
        end

        if tty
          stream.puts
        else
          stream.puts display_progress(displayed_filename, width, size, expected_size)
        end
      else
        hf_raise_for_status(response)
      end
    end
  end
end
172
+
173
# Normalize an ETag header value into the string used as the blob file name.
#
# Leading "W" and "/" characters (the weak-validator prefix `W/`) are stripped, then surrounding
# double quotes. This matches the normalization of the Python huggingface_hub library, so blobs
# get the same names in a shared cache. The previous `sub(/\A\W/, "")` left `W/"abc"` as `W/abc`,
# which produced a nested "blobs/W/abc" path.
#
# @param etag [String, nil]
# @return [String, nil] nil when `etag` is nil or empty.
def _normalize_etag(etag)
  return nil if etag.nil? || etag.empty?

  etag.sub(%r{\A[W/]+}, "").sub(/\A"+/, "").sub(/"+\z/, "")
end
179
+
180
# Replace `dst` (if present) with a symlink to `src`.
#
# The link target is written relative to `dst`'s folder, as the Python huggingface_hub library does.
# A target copied verbatim would dangle whenever `src` is a relative path (e.g. a relative
# `cache_dir`), because symlink targets are resolved from the link's own directory. Relative links
# also survive moving the whole cache folder. When no relative path exists (e.g. different drives
# on Windows), the absolute path is used.
#
# @param src [String] blob file the link should point to.
# @param dst [String] pointer path to (re)create.
# @param new_blob [Boolean] accepted for API compatibility; unused.
def _create_symlink(src, dst, new_blob: false)
  begin
    FileUtils.rm(dst)
  rescue Errno::ENOENT
    # nothing to replace
  end

  abs_src = Pathname.new(File.expand_path(src))
  abs_dst = Pathname.new(File.expand_path(dst))
  link_target =
    begin
      abs_src.relative_path_from(abs_dst.dirname).to_s
    rescue ArgumentError
      abs_src.to_s
    end

  FileUtils.symlink(link_target, abs_dst.to_s)
end
193
+
194
# Record which commit a branch/tag `revision` resolved to, in `<storage_folder>/refs/<revision>`.
#
# Does nothing when `revision` already is `commit_hash`. The ref file is only rewritten when its
# content differs, which avoids a pointless write failure on read-only caches where the ref is
# already current (see https://github.com/huggingface/huggingface_hub/issues/1216).
def _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
  return if revision == commit_hash

  ref_file = Pathname.new(storage_folder) / "refs" / revision
  ref_file.parent.mkpath
  up_to_date = ref_file.exist? && ref_file.read == commit_hash
  ref_file.write(commit_hash) unless up_to_date
end
206
+
207
# Cache directory name for a repo: the pluralized repo type followed by every "/"-separated part of
# `repo_id`, all joined with REPO_ID_SEPARATOR. For example, assuming the separator is "--",
# ("org/name", "model") gives "models--org--name".
def repo_folder_name(repo_id:, repo_type:)
  ["#{repo_type}s", *repo_id.split("/")].join(REPO_ID_SEPARATOR)
end
212
+
213
# Placeholder for a free-space check before writing `expected_size` bytes into `target_dir`.
# Not implemented yet: it inspects nothing and always returns nil.
def _check_disk_space(expected_size, target_dir)
  # TODO: implement the disk space check
  nil
end
216
+
217
# Download `filename` from `repo_id` into the local cache and return the cached path.
#
# Only the cache-directory layout is implemented: passing `local_dir` raises `Todo`.
# `local_dir_use_symlinks`, `force_filename`, `resume_download` and `legacy_cache_layout` are
# accepted for API compatibility but are not used here.
#
# @param repo_id [String] e.g. "org/name".
# @param filename [String] path inside the repo; `subfolder` ("" means none) is joined in front with "/".
# @param repo_type [String, nil] defaults to "model"; must be one of REPO_TYPES.
# @param revision [String, nil] branch, tag or commit hash; defaults to DEFAULT_REVISION.
# @param cache_dir [String, nil] defaults to HF_HUB_CACHE.
# @return [String] whatever `_hf_hub_download_to_cache_dir` returns (the pointer path).
# @raise [ArgumentError] for an unknown repo type.
def hf_hub_download(
  repo_id,
  filename,
  subfolder: nil,
  repo_type: nil,
  revision: nil,
  library_name: nil,
  library_version: nil,
  cache_dir: nil,
  local_dir: nil,
  local_dir_use_symlinks: "auto",
  user_agent: nil,
  force_download: false,
  force_filename: nil,
  proxies: nil,
  etag_timeout: DEFAULT_ETAG_TIMEOUT,
  resume_download: false,
  token: nil,
  local_files_only: false,
  legacy_cache_layout: false,
  endpoint: nil
)
  cache_dir = HF_HUB_CACHE if cache_dir.nil?
  revision = DEFAULT_REVISION if revision.nil?

  # A forward slash on purpose: the joined name ends up in a URL, not a local path.
  subfolder = nil if subfolder == ""
  filename = "#{subfolder}/#{filename}" unless subfolder.nil?

  repo_type = "model" if repo_type.nil?
  unless REPO_TYPES.include?(repo_type)
    raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
  end

  request_headers = build_hf_headers(
    token: token,
    library_name: library_name,
    library_version: library_version,
    user_agent: user_agent
  )

  # Downloading into a user-chosen directory is not supported yet.
  raise Todo unless local_dir.nil?

  _hf_hub_download_to_cache_dir(
    cache_dir: cache_dir,
    repo_id: repo_id,
    filename: filename,
    repo_type: repo_type,
    revision: revision,
    endpoint: endpoint,
    etag_timeout: etag_timeout,
    headers: request_headers,
    proxies: proxies,
    token: token,
    local_files_only: local_files_only,
    force_download: force_download
  )
end
292
+
293
# Resolve `filename` at `revision` into the cache layout under `cache_dir`, downloading it if needed.
#
# Layout (under `<cache_dir>/<repo_folder_name>`): content is stored once in `blobs/<etag>`, and
# `snapshots/<commit_hash>/<filename>` is a symlink to it. For a branch/tag revision,
# `refs/<revision>` records the commit it resolved to.
#
# When the HEAD call fails (offline, local_files_only, repo not accessible), an already cached
# pointer for the revision is returned if present; otherwise `_raise_on_head_call_error` raises.
#
# @return [String] path of the pointer file in the snapshots folder.
def _hf_hub_download_to_cache_dir(
  cache_dir:,
  # File info
  repo_id:,
  filename:,
  repo_type:,
  revision:,
  # HTTP info
  endpoint:,
  etag_timeout:,
  headers:,
  proxies:,
  token:,
  # Additional options
  local_files_only:,
  force_download:
)
  locks_dir = File.join(cache_dir, ".locks")
  storage_folder = File.join(cache_dir, repo_folder_name(repo_id: repo_id, repo_type: repo_type))

  # cross platform transcription of filename, to be used as a local file path.
  relative_filename = File.join(*filename.split("/"))

  # if user provides a commit_hash and they already have the file on disk, shortcut everything.
  if REGEX_COMMIT_HASH.match?(revision)
    pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
    return pointer_path if File.exist?(pointer_path) && !force_download
  end

  # Try to get metadata (etag, commit_hash, url, size) from the server.
  # If we can't, a HEAD request error is returned.
  url_to_download, etag, commit_hash, expected_size, head_call_error = _get_metadata_or_catch_error(
    repo_id: repo_id,
    filename: filename,
    repo_type: repo_type,
    revision: revision,
    endpoint: endpoint,
    proxies: proxies,
    etag_timeout: etag_timeout,
    headers: headers,
    token: token,
    local_files_only: local_files_only,
    storage_folder: storage_folder,
    relative_filename: relative_filename
  )

  # etag can be nil for several reasons:
  # 1. we passed local_files_only.
  # 2. we don't have a connection
  # 3. Hub is down (HTTP 500 or 504)
  # 4. repo is not found -for example private or gated- and invalid/missing token sent
  # 5. Hub is blocked by a firewall or proxy is not set correctly.
  # => Try to get the last downloaded one from the specified revision.
  #
  # If the specified revision is a commit hash, look inside "snapshots".
  # If the specified revision is a branch or tag, look inside "refs".
  unless head_call_error.nil?
    # Couldn't make a HEAD call => let's try to find a local file
    unless force_download
      commit_hash =
        if REGEX_COMMIT_HASH.match?(revision)
          revision
        else
          ref_path = File.join(storage_folder, "refs", revision)
          File.read(ref_path) if File.exist?(ref_path)
        end

      # Return pointer file if exists
      unless commit_hash.nil?
        pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
        return pointer_path if File.exist?(pointer_path)
      end
    end

    # Otherwise, raise appropriate error
    _raise_on_head_call_error(head_call_error, force_download, local_files_only)
  end

  # From now on, etag and commit_hash are not nil.
  raise "etag must have been retrieved from server" if etag.nil?
  raise "commit_hash must have been retrieved from server" if commit_hash.nil?
  raise "file location must have been retrieved from server" if url_to_download.nil?
  raise "expected_size must have been retrieved from server" if expected_size.nil?
  blob_path = File.join(storage_folder, "blobs", etag)
  pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)

  FileUtils.mkdir_p(File.dirname(blob_path))
  FileUtils.mkdir_p(File.dirname(pointer_path))

  # if passed revision is not identical to commit_hash
  # then revision has to be a branch name or tag name.
  # In that case store a ref.
  _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

  unless force_download
    return pointer_path if File.exist?(pointer_path)

    if File.exist?(blob_path)
      # we have the blob already, but not the pointer
      _create_symlink(blob_path, pointer_path, new_blob: false)
      return pointer_path
    end
  end

  # Prevent parallel downloads of the same file with a lock. Without it, concurrent processes would
  # append to the same ".incomplete" file (opened in append mode) and corrupt the blob.
  # etag could be duplicated across repos, so the lock lives under the repo's folder name.
  lock_path = File.join(locks_dir, repo_folder_name(repo_id: repo_id, repo_type: repo_type), "#{etag}.lock")
  FileUtils.mkdir_p(File.dirname(lock_path))
  File.open(lock_path, File::RDWR | File::CREAT) do |lock_file|
    # Blocks until no other process holds the lock; released when the file is closed.
    # A process that waited here finds the finished blob and `_download_to_tmp_and_move` returns early.
    lock_file.flock(File::LOCK_EX)
    _download_to_tmp_and_move(
      incomplete_path: Pathname.new(blob_path + ".incomplete"),
      destination_path: Pathname.new(blob_path),
      url_to_download: url_to_download,
      proxies: proxies,
      headers: headers,
      expected_size: expected_size,
      filename: filename,
      force_download: force_download
    )
    _create_symlink(blob_path, pointer_path, new_blob: true)
  end

  pointer_path
end
423
+
424
# Look up `filename` for `revision` in the local cache without any network access.
#
# Returns the cached snapshot path when present. Returns CACHED_NO_EXIST when the cache holds a
# `.no_exist` marker for that file and revision. Returns nil otherwise: no repo cache, unknown
# revision, or file not cached. A branch/tag name is resolved through `refs/<revision>` when that
# file exists. No other snapshot is used as a fallback.
#
# @raise [ArgumentError] when repo_type is not in REPO_TYPES.
def try_to_load_from_cache(repo_id, filename, cache_dir: nil, revision: nil, repo_type: nil)
  revision = "main" if revision.nil?
  repo_type = "model" if repo_type.nil?
  unless REPO_TYPES.include?(repo_type)
    raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
  end
  cache_dir = HF_HUB_CACHE if cache_dir.nil?

  repo_cache = File.join(cache_dir, "#{repo_type}s--#{repo_id.gsub("/", "--")}")
  # No cache for this model
  return nil unless Dir.exist?(repo_cache)

  refs_dir, snapshots_dir, no_exist_dir = %w[refs snapshots .no_exist].map { |sub| File.join(repo_cache, sub) }

  # Resolve refs (for instance to convert main to the associated commit sha)
  if Dir.exist?(refs_dir)
    ref_file = File.join(refs_dir, revision)
    revision = File.read(ref_file) if File.exist?(ref_file)
  end

  return CACHED_NO_EXIST if File.exist?(File.join(no_exist_dir, revision, filename))
  return nil unless Dir.exist?(snapshots_dir)
  # Only the exact revision counts; we never return a file from some other snapshot.
  return nil unless Dir.glob("*", base: snapshots_dir).include?(revision)

  cached_file = File.join(snapshots_dir, revision, filename)
  cached_file if File.exist?(cached_file)
end
482
+
483
# Issue a HEAD request on `url` and collect the file metadata the Hub reports.
#
# The custom linked-etag / linked-size headers are preferred over the plain `etag` /
# `content-length` headers. The location is the response's `location` header when present,
# otherwise the URL actually requested (which may differ from `url` after relative redirects).
#
# @return [HfFileMetadata]
def get_hf_file_metadata(
  url,
  token: nil,
  proxies: nil,
  timeout: DEFAULT_REQUEST_TIMEOUT,
  library_name: nil,
  library_version: nil,
  user_agent: nil,
  headers: nil
)
  request_headers = build_hf_headers(
    token: token,
    library_name: library_name,
    library_version: library_version,
    user_agent: user_agent,
    headers: headers
  )
  # Disable compression so that the reported length is the real file size.
  request_headers["Accept-Encoding"] = "identity"

  response = _request_wrapper(
    "HEAD",
    url,
    headers: request_headers,
    allow_redirects: false,
    follow_relative_redirects: true,
    proxies: proxies,
    timeout: timeout
  )
  hf_raise_for_status(response)

  raw_etag = response[HUGGINGFACE_HEADER_X_LINKED_ETAG] || response["etag"]
  raw_size = response[HUGGINGFACE_HEADER_X_LINKED_SIZE] || response["content-length"]
  HfFileMetadata.new(
    commit_hash: response[HUGGINGFACE_HEADER_X_REPO_COMMIT],
    etag: _normalize_etag(raw_etag),
    location: response["location"] || response.uri.to_s,
    size: _int_or_none(raw_size)
  )
end
529
+
530
# Fetch file metadata from the Hub, returning recoverable failures instead of raising them.
#
# Returns `[url_to_download, etag, commit_hash, expected_size, head_error_call]`.
# - On success, head_error_call is nil and the other four are set.
# - On a recoverable failure, head_error_call holds the error and the caller may fall back to the
#   local cache. Recoverable failures are: offline mode, a connection problem or timeout, or a
#   repository that is not found or gated. In that case url_to_download is the computed Hub URL and
#   the rest are nil.
# Any other error (e.g. SSL errors) propagates.
#
# When the file is served from another host (e.g. an LFS/CDN redirect), the "authorization" entry
# is removed from `headers` in place so the token is not sent to that host.
def _get_metadata_or_catch_error(
  repo_id:,
  filename:,
  repo_type:,
  revision:,
  endpoint:,
  proxies:,
  etag_timeout:,
  headers:, # mutated inplace!
  token:,
  local_files_only:,
  relative_filename: nil, # only used to store `.no_exists` in cache
  storage_folder: nil # only used to store `.no_exists` in cache
)
  if local_files_only
    return [
      nil,
      nil,
      nil,
      nil,
      OfflineModeIsEnabled.new(
        "Cannot access file since 'local_files_only: true' has been set. (repo_id: #{repo_id}, repo_type: #{repo_type}, revision: #{revision}, filename: #{filename})"
      )
    ]
  end

  url = hf_hub_url(repo_id, filename, repo_type: repo_type, revision: revision, endpoint: endpoint)
  url_to_download = url
  etag = nil
  commit_hash = nil
  expected_size = nil
  head_error_call = nil

  metadata = nil
  begin
    metadata =
      get_hf_file_metadata(
        url,
        proxies: proxies,
        timeout: etag_timeout,
        headers: headers,
        token: token
      )
  rescue SocketError, Timeout::Error, Errno::ECONNREFUSED, Errno::ECONNRESET, Errno::EHOSTUNREACH,
         Errno::ENETUNREACH, Errno::ETIMEDOUT, OfflineModeIsEnabled, RepositoryNotFoundError,
         GatedRepoError => e
    # Internet connection down, or repo private/gated/missing => let the caller look in the local
    # cache. `_raise_on_head_call_error` re-raises (or wraps) this error if nothing is cached.
    # NOTE(review): upstream also falls back on generic HTTP errors (Hub 5xx); add the base HTTP
    # error class here once confirmed.
    head_error_call = e
  end

  if head_error_call.nil?
    # Commit hash must exist
    commit_hash = metadata.commit_hash
    raise Todo if commit_hash.nil?

    # Etag must exist
    etag = metadata.etag
    raise Todo if etag.nil?

    # Expected (uncompressed) size
    expected_size = metadata.size
    raise Todo if expected_size.nil?

    if metadata.location != url
      url_to_download = metadata.location
      if netloc(URI.parse(url)) != netloc(URI.parse(metadata.location))
        # Remove authorization header when downloading a LFS blob
        headers.delete("authorization")
      end
    end
  end

  if etag.nil? && head_error_call.nil?
    raise "etag is empty due to uncovered problems"
  end

  [url_to_download, etag, commit_hash, expected_size, head_error_call]
end
612
+
613
# Raise the error that explains why a file could neither be fetched nor found in the cache.
#
# Called after the HEAD request failed (`head_call_error`) and no usable cached file was found.
# The HEAD error is attached as the `cause` of the raised exception (except where it is re-raised
# directly), so messages such as "due to the above error" refer to something visible in the backtrace.
#
# @raise [ArgumentError] when force_download was requested.
# @raise [LocalEntryNotFoundError] when offline/local-only or on a likely connection problem.
# @raise [RepositoryNotFoundError, GatedRepoError] re-raised as-is.
def _raise_on_head_call_error(head_call_error, force_download, local_files_only)
  # No head call => we cannot force download.
  if force_download
    if local_files_only
      raise ArgumentError, "Cannot pass 'force_download: true' and 'local_files_only: true' at the same time."
    elsif head_call_error.is_a?(OfflineModeIsEnabled)
      raise ArgumentError, "Cannot pass 'force_download: true' when offline mode is enabled.", cause: head_call_error
    else
      raise ArgumentError, "Force download failed due to the above error.", cause: head_call_error
    end
  end

  # If we couldn't find an appropriate file on disk, raise an error.
  # If files cannot be found and local_files_only=True,
  # the models might've been found if local_files_only=False
  # Notify the user about that
  if local_files_only
    raise LocalEntryNotFoundError,
      "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable" +
      " hf.co look-ups and downloads online, set 'local_files_only' to false."
  elsif head_call_error.is_a?(RepositoryNotFoundError) || head_call_error.is_a?(GatedRepoError)
    # Repo not found or gated => let's raise the actual error
    raise head_call_error
  else
    # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
    raise LocalEntryNotFoundError,
      "An error happened while trying to locate the file on the Hub and we cannot find the requested files" +
      " in the local cache. Please check your connection and try again or make sure your Internet connection" +
      " is on.",
      cause: head_call_error
  end
end
644
+
645
# Download `url_to_download` into `incomplete_path`, then move it to `destination_path`.
#
# Returns immediately when the destination already exists, unless force_download is set. An existing
# incomplete file is resumed, except when force_download is set or hf_transfer is enabled without
# proxies; in those two cases it is deleted first.
#
# @param incomplete_path [Pathname] temporary file, opened in append mode.
# @param destination_path [Pathname] final blob path.
def _download_to_tmp_and_move(
  incomplete_path:,
  destination_path:,
  url_to_download:,
  proxies:,
  headers:,
  expected_size:,
  filename:,
  force_download:
)
  # Do nothing if already exists (except if force_download=True)
  return if destination_path.exist? && !force_download

  if incomplete_path.exist?
    discard_reason =
      if force_download
        "force_download: true"
      elsif HF_HUB_ENABLE_HF_TRANSFER && !proxies
        "hf_transfer: true"
      end
    unless discard_reason.nil?
      Transformers.logger.info("Removing incomplete file '#{incomplete_path}' (#{discard_reason})")
      incomplete_path.unlink
    end
  end

  incomplete_path.open("ab") do |file|
    file.seek(0, IO::SEEK_END)
    resume_size = file.tell

    log_line = "Downloading '#{filename}' to '#{incomplete_path}'"
    log_line += " (resume from #{resume_size}/#{expected_size})" if resume_size > 0 && !expected_size.nil?
    Transformers.logger.info(log_line)

    # expected_size might be nil if the HTTP header was not set correctly
    unless expected_size.nil?
      [incomplete_path.parent, destination_path.parent].each do |target_dir|
        _check_disk_space(expected_size, target_dir)
      end
    end

    http_get(
      url_to_download,
      file,
      proxies: proxies,
      resume_size: resume_size,
      headers: headers,
      expected_size: expected_size
    )
  end

  Transformers.logger.info("Download complete. Moving file to #{destination_path}")
  _chmod_and_move(incomplete_path, destination_path)
end
702
+
703
# Parse an HTTP header value (e.g. a content length) as a base-10 Integer.
#
# nil and unparsable values yield nil. String#to_i would instead silently turn them into 0,
# which would then pass the caller's "size must be known" nil check.
#
# @param value [String, Integer, nil]
# @return [Integer, nil]
def _int_or_none(value)
  return nil if value.nil?
  return value if value.is_a?(Integer)

  Integer(value.to_s, 10, exception: false)
end
706
+
707
# Give `src` the permissions a newly created file gets in the cache, then move it to `dst`.
#
# The permissions are sampled by touching a throw-away `tmp_<uuid>` file two levels above `dst`,
# so the effective umask is respected. Setting permissions is best-effort: a failure is logged
# and the move still happens.
#
# @param src [Pathname] downloaded (incomplete) file.
# @param dst [Pathname] final location.
def _chmod_and_move(src, dst)
  tmp_file = dst.parent.parent / "tmp_#{SecureRandom.uuid}"
  begin
    FileUtils.touch(tmp_file)
    cache_dir_mode = Pathname.new(tmp_file).stat.mode
    # Keep only the permission bits; st_mode also carries file-type bits.
    src.chmod(cache_dir_mode & 0o7777)
  rescue SystemCallError => e
    Transformers.logger.warn(
      "Could not set the permissions on the file '#{src}'. Error: #{e}.\nContinuing without setting permissions."
    )
  ensure
    begin
      tmp_file.unlink
    rescue SystemCallError
      # fails if `tmp_file.touch()` failed => do nothing
      # See https://github.com/huggingface/huggingface_hub/issues/2359
    end
  end

  FileUtils.move(src.to_s, dst.to_s)
end
724
+
725
# Build `<storage_folder>/snapshots/<revision>/<relative_filename>`, refusing paths that escape the
# snapshots folder (e.g. a revision or filename containing "..").
#
# Paths are made absolute with File.absolute_path, which collapses ".." without resolving symlinks.
# Pathname#ascend walks every ancestor, with no fixed depth limit.
#
# @return [String] the pointer path (as built, not absolutized).
# @raise [ArgumentError] when the snapshots folder is not a strict ancestor of the pointer path.
def _get_pointer_path(storage_folder, revision, relative_filename)
  snapshot_path = File.join(storage_folder, "snapshots")
  pointer_path = File.join(snapshot_path, revision, relative_filename)

  abs_snapshot = Pathname.new(File.absolute_path(snapshot_path))
  abs_pointer = Pathname.new(File.absolute_path(pointer_path))
  # ascend yields the path itself first; drop it so only strict ancestors count.
  unless abs_pointer.ascend.drop(1).include?(abs_snapshot)
    raise ArgumentError,
      "Invalid pointer path: cannot create pointer path in snapshot folder if" +
      " `storage_folder: #{storage_folder.inspect}`, `revision: #{revision.inspect}` and" +
      " `relative_filename: #{relative_filename.inspect}`."
  end
  pointer_path
end
736
+
737
+ # additional methods
738
+
739
# "host:port" of a parsed URI, skipping whichever part is nil.
# A relative URI (no host, no port) gives "", which callers use to detect relative redirects.
def netloc(uri)
  [uri.host, uri.port].reject(&:nil?).join(":")
end
742
+
743
# Successive Pathname#parent values of `path`, nearest first, excluding `path` itself.
# Stops when a path is its own parent (the root) or after 100 entries. For relative paths,
# `parent` keeps producing "..", "../.." and so on, so that cap is what ends the walk.
def parents(path)
  ancestors = []
  current = path
  until ancestors.size == 100 || current == current.parent
    current = current.parent
    ancestors << current
  end
  ancestors
end
754
+
755
# Render a one-line text progress bar, `"<filename> |████    |"`, fitted to `width` columns.
#
# Never raises on odd inputs, since it is called in the middle of a download:
# - a filename longer than the terminal (or a width of 0) gives an empty bar instead of a negative
#   String#* count;
# - an unknown size shows an empty bar, and a zero size a full one, instead of dividing by 0.0;
# - more bytes than expected are clamped to a full bar.
#
# @param filename [String] label printed before the bar.
# @param width [Integer] total columns available.
# @param size [Integer] bytes received so far.
# @param expected_size [Integer, nil] total bytes expected.
# @return [String]
def display_progress(filename, width, size, expected_size)
  bar_width = [width - (filename.length + 3), 0].max
  progress =
    if expected_size.nil?
      0.0
    elsif expected_size <= 0
      1.0
    else
      (size / expected_size.to_f).clamp(0.0, 1.0)
    end
  done = (progress * bar_width).round
  "#{filename} |#{"█" * done}#{" " * (bar_width - done)}|"
end
762
+ end
763
+ end
764
+ end