transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
@@ -0,0 +1,764 @@
1
module Transformers
  module HfHub
    # Return value when trying to load a file from cache but the file does not exist in the distant repo.
    # Compared by identity (a unique sentinel object), never by value.
    CACHED_NO_EXIST = Object.new

    # Regex to get filename from a "Content-Disposition" header for CDN-served files
    HEADER_FILENAME_PATTERN = /filename="(?<filename>.*?)";/

    # Regex to check if the revision IS directly a commit_hash (40 lowercase hex chars)
    REGEX_COMMIT_HASH = /^[0-9a-f]{40}$/
11
+
12
# Value object describing a file hosted on the Hub, as resolved from the
# headers of an HTTP HEAD request: commit hash, normalized etag, final
# (possibly redirected) location, and size in bytes.
class HfFileMetadata
  attr_reader :commit_hash, :etag, :location, :size

  def initialize(commit_hash:, etag:, location:, size:)
    @commit_hash, @etag, @location, @size = commit_hash, etag, location, size
  end
end
22
+
23
+ class << self
24
# Build the canonical URL used to download a file from a repo on the Hub.
#
# repo_id  - "namespace/name" of the repository.
# filename - path of the file inside the repo (prefixed with +subfolder+ when given).
# repo_type: must be one of REPO_TYPES; types listed in REPO_TYPES_URL_PREFIXES
#            (e.g. datasets/spaces) get their URL prefix prepended to repo_id.
# revision: branch, tag or commit hash; defaults to DEFAULT_REVISION.
# endpoint: alternative Hub endpoint substituted for the default ENDPOINT.
def hf_hub_url(
  repo_id,
  filename,
  subfolder: nil,
  repo_type: nil,
  revision: nil,
  endpoint: nil
)
  if subfolder == ""
    subfolder = nil
  end
  if !subfolder.nil?
    # Forward slash on purpose: this builds a URL, not a local path.
    filename = "#{subfolder}/#{filename}"
  end

  if !REPO_TYPES.include?(repo_type)
    # Match the error message style used by the other methods in this file.
    raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
  end

  if REPO_TYPES_URL_PREFIXES.include?(repo_type)
    repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
  end

  if revision.nil?
    revision = DEFAULT_REVISION
  end
  url =
    HUGGINGFACE_CO_URL_TEMPLATE %
    {repo_id: repo_id, revision: CGI.escape(revision), filename: CGI.escape(filename)}
  # Update endpoint if provided
  if !endpoint.nil? && url.start_with?(ENDPOINT)
    url = endpoint + url[ENDPOINT.length..]
  end
  url
end
59
+
60
# Perform a single HTTP request and raise on error status via
# hf_raise_for_status. When +follow_relative_redirects+ is true, relative
# `Location:` redirects (same host) are followed recursively, up to 10 hops;
# absolute redirects are returned to the caller untouched.
#
# NOTE(review): the non-redirect branch sends `uri.path`, which drops any
# query string — confirm callers never pass URLs carrying query parameters.
def _request_wrapper(method, url, follow_relative_redirects: false, redirects: 0, **params)
  # Recursively follow relative redirects
  if follow_relative_redirects
    if redirects > 10
      raise "Too many redirects"
    end

    # Issue the plain request first; redirect handling happens below.
    response = _request_wrapper(
      method,
      url,
      follow_relative_redirects: false,
      **params
    )

    # If redirection, we redirect only relative paths.
    # This is useful in case of a renamed repository.
    if response.is_a?(Net::HTTPRedirection)
      parsed_target = URI.parse(response["Location"])
      if netloc(parsed_target) == ""
        # This means it is a relative 'location' headers, as allowed by RFC 7231.
        # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
        # We want to follow this relative redirect !
        #
        # Highly inspired by `resolve_redirects` from requests library.
        # See https://github.com/psf/requests/blob/main/requests/sessions.py#L159
        next_url = URI.parse(url)
        next_url.path = parsed_target.path
        return _request_wrapper(method, next_url, follow_relative_redirects: true, redirects: redirects + 1, **params)
      end
    end
    return response
  end

  # Perform request and return if status_code is not in the retry list.
  uri = URI.parse(url)

  http_options = {use_ssl: true}
  if params[:timeout]
    # Apply the same timeout to connect, read and write.
    http_options[:open_timeout] = params[:timeout]
    http_options[:read_timeout] = params[:timeout]
    http_options[:write_timeout] = params[:timeout]
  end
  response =
    Net::HTTP.start(uri.host, uri.port, **http_options) do |http|
      http.send_request(method, uri.path, nil, params[:headers])
    end
  # Responses built this way do not carry the request URI; remember it so
  # callers (e.g. get_hf_file_metadata) can fall back to it.
  response.uri ||= uri
  hf_raise_for_status(response)
  response
end
110
+
111
# Stream a GET request into +temp_file+, printing a textual progress bar to
# STDERR (redrawn in place on a TTY, printed once otherwise).
#
# url            - fully resolved download URL.
# temp_file      - open, writable IO the body chunks are appended to.
# resume_size:   - byte offset to resume from; sends a Range header when > 0.
# headers:       - request headers; nil is tolerated and the hash is copied so
#                  the caller's hash is never mutated (the Range header used
#                  to be written into the caller's hash).
# expected_size: - total size in bytes, used only for progress display.
# displayed_filename: - label for the progress bar; defaults to the URL, or to
#                  the CDN filename from Content-Disposition when present.
def http_get(
  url,
  temp_file,
  proxies: nil,
  resume_size: 0,
  headers: nil,
  expected_size: nil,
  displayed_filename: nil,
  _nb_retries: 5
)
  uri = URI.parse(url)

  # Defensive copy: do not mutate the caller's header hash, and allow nil.
  headers = (headers || {}).dup
  if resume_size > 0
    headers["range"] = "bytes=%d-" % [resume_size]
  end

  size = resume_size
  Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
    request = Net::HTTP::Get.new(uri)
    headers.each do |k, v|
      request[k] = v
    end
    http.request(request) do |response|
      case response
      when Net::HTTPSuccess
        if displayed_filename.nil?
          displayed_filename = url
          content_disposition = response["content-disposition"]
          if !content_disposition.nil?
            match = HEADER_FILENAME_PATTERN.match(content_disposition)
            if !match.nil?
              # Means file is on CDN
              displayed_filename = match["filename"]
            end
          end
        end

        stream = STDERR
        tty = stream.tty?
        width = tty ? stream.winsize[1] : 80

        response.read_body do |chunk|
          temp_file.write(chunk)
          size += chunk.bytesize

          # Redraw the bar in place only on interactive terminals.
          if tty
            stream.print "\r#{display_progress(displayed_filename, width, size, expected_size)}"
          end
        end

        if tty
          stream.puts
        else
          stream.puts display_progress(displayed_filename, width, size, expected_size)
        end
      else
        hf_raise_for_status(response)
      end
    end
  end
end
172
+
173
# Normalize an ETag header value so it can be used as a blob file name:
# drop the weak-validator prefix ("W/", RFC 7232) and the surrounding double
# quotes. Returns nil when the header is missing.
def _normalize_etag(etag)
  if etag.nil?
    return nil
  end
  # The previous /\A\W/ regex never stripped the weak prefix because "W" is a
  # word character; delete_prefix handles it explicitly.
  etag.delete_prefix("W/").delete('"')
end
179
+
180
# Point +dst+ at +src+, replacing whatever was previously at +dst+.
# The +new_blob+ flag is accepted for API parity but currently unused.
def _create_symlink(src, dst, new_blob: false)
  begin
    File.unlink(dst)
  rescue Errno::ENOENT
    # destination did not exist yet — nothing to remove
  end

  FileUtils.symlink(src, dst)
end
193
+
194
# Record revision -> commit_hash under "<storage_folder>/refs/<revision>" so a
# branch or tag name can later be resolved locally. No-op when the revision
# already *is* the commit hash.
def _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)
  return if revision == commit_hash

  ref_path = Pathname.new(storage_folder) / "refs" / revision
  ref_path.parent.mkpath
  # Only rewrite a stale ref: avoids a needless failure when the repo is
  # already cached but the cache folder is not writable.
  # See https://github.com/huggingface/huggingface_hub/issues/1216.
  ref_path.write(commit_hash) unless ref_path.exist? && ref_path.read == commit_hash
end
206
+
207
# Flatten a repo id ("org/name") plus its type into a single filesystem-safe
# directory name: "<type>s<SEP>org<SEP>name" (removes all "/" occurrences).
def repo_folder_name(repo_id:, repo_type:)
  ["#{repo_type}s", *repo_id.split("/")].join(REPO_ID_SEPARATOR)
end
212
+
213
# Placeholder: upstream huggingface_hub verifies there is enough free disk
# space for +expected_size+ bytes under +target_dir+ before downloading.
# Intentionally a no-op in this port for now.
def _check_disk_space(expected_size, target_dir)
  # TODO
end
216
+
217
# Download a file from a repo on the Hub and return its local path inside the
# cache (a symlink into the blobs store). Validates arguments, builds auth
# headers, then delegates to _hf_hub_download_to_cache_dir.
#
# Several parameters (force_filename, resume_download, local_dir_use_symlinks,
# legacy_cache_layout) are accepted only for API compatibility with
# huggingface_hub and are currently unused. +local_dir+ is not implemented
# yet and raises Todo.
def hf_hub_download(
  repo_id,
  filename,
  subfolder: nil,
  repo_type: nil,
  revision: nil,
  library_name: nil,
  library_version: nil,
  cache_dir: nil,
  local_dir: nil,
  local_dir_use_symlinks: "auto",
  user_agent: nil,
  force_download: false,
  force_filename: nil,
  proxies: nil,
  etag_timeout: DEFAULT_ETAG_TIMEOUT,
  resume_download: false,
  token: nil,
  local_files_only: false,
  legacy_cache_layout: false,
  endpoint: nil
)
  if cache_dir.nil?
    cache_dir = HF_HUB_CACHE
  end
  if revision.nil?
    revision = DEFAULT_REVISION
  end

  if subfolder == ""
    subfolder = nil
  end
  if !subfolder.nil?
    # This is used to create a URL, and not a local path, hence the forward slash.
    filename = "#{subfolder}/#{filename}"
  end

  if repo_type.nil?
    repo_type = "model"
  end
  if !REPO_TYPES.include?(repo_type)
    raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
  end

  headers =
    build_hf_headers(
      token: token,
      library_name: library_name,
      library_version: library_version,
      user_agent: user_agent
    )

  if !local_dir.nil?
    raise Todo
  else
    _hf_hub_download_to_cache_dir(
      # Destination
      cache_dir: cache_dir,
      # File info
      repo_id: repo_id,
      filename: filename,
      repo_type: repo_type,
      revision: revision,
      # HTTP info
      endpoint: endpoint,
      etag_timeout: etag_timeout,
      headers: headers,
      proxies: proxies,
      token: token,
      # Additional options
      local_files_only: local_files_only,
      force_download: force_download
    )
  end
end
292
+
293
# Core cache-layout download logic. Steps:
#   1. shortcut when +revision+ is a commit hash and the pointer file exists,
#   2. HEAD the Hub for (url, etag, commit_hash, size),
#   3. when the HEAD call failed, fall back to a previously cached revision,
#   4. otherwise download the blob and symlink snapshots/<commit>/<file> to it.
# Returns the pointer path (String) inside the snapshots folder.
def _hf_hub_download_to_cache_dir(
  cache_dir:,
  # File info
  repo_id:,
  filename:,
  repo_type:,
  revision:,
  # HTTP info
  endpoint:,
  etag_timeout:,
  headers:,
  proxies:,
  token:,
  # Additional options
  local_files_only:,
  force_download:
)
  # NOTE(review): locks_dir is computed but locking is not implemented yet
  # (see the commented-out lock_path below).
  _locks_dir = File.join(cache_dir, ".locks")
  storage_folder = File.join(cache_dir, repo_folder_name(repo_id: repo_id, repo_type: repo_type))

  # cross platform transcription of filename, to be used as a local file path.
  relative_filename = File.join(*filename.split("/"))

  # if user provides a commit_hash and they already have the file on disk, shortcut everything.
  if REGEX_COMMIT_HASH.match?(revision)
    pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
    if File.exist?(pointer_path) && !force_download
      return pointer_path
    end
  end

  # Try to get metadata (etag, commit_hash, url, size) from the server.
  # If we can't, a HEAD request error is returned.
  url_to_download, etag, commit_hash, expected_size, head_call_error = _get_metadata_or_catch_error(
    repo_id: repo_id,
    filename: filename,
    repo_type: repo_type,
    revision: revision,
    endpoint: endpoint,
    proxies: proxies,
    etag_timeout: etag_timeout,
    headers: headers,
    token: token,
    local_files_only: local_files_only,
    storage_folder: storage_folder,
    relative_filename: relative_filename
  )

  # etag can be None for several reasons:
  # 1. we passed local_files_only.
  # 2. we don't have a connection
  # 3. Hub is down (HTTP 500 or 504)
  # 4. repo is not found -for example private or gated- and invalid/missing token sent
  # 5. Hub is blocked by a firewall or proxy is not set correctly.
  # => Try to get the last downloaded one from the specified revision.
  #
  # If the specified revision is a commit hash, look inside "snapshots".
  # If the specified revision is a branch or tag, look inside "refs".
  if !head_call_error.nil?
    # Couldn't make a HEAD call => let's try to find a local file
    if !force_download
      commit_hash = nil
      if REGEX_COMMIT_HASH.match(revision)
        commit_hash = revision
      else
        # Resolve branch/tag to a commit hash via the cached ref file.
        ref_path = File.join(storage_folder, "refs", revision)
        if File.exist?(ref_path)
          commit_hash = File.read(ref_path)
        end
      end

      # Return pointer file if exists
      if !commit_hash.nil?
        pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
        if File.exist?(pointer_path) && !force_download
          return pointer_path
        end
      end
    end

    # Otherwise, raise appropriate error
    _raise_on_head_call_error(head_call_error, force_download, local_files_only)
  end

  # From now on, etag and commit_hash are not None.
  raise "etag must have been retrieved from server" if etag.nil?
  raise "commit_hash must have been retrieved from server" if commit_hash.nil?
  raise "file location must have been retrieved from server" if url_to_download.nil?
  raise "expected_size must have been retrieved from server" if expected_size.nil?
  # Blobs are stored by etag; pointers live under snapshots/<commit_hash>/.
  blob_path = File.join(storage_folder, "blobs", etag)
  pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)

  FileUtils.mkdir_p(File.dirname(blob_path))
  FileUtils.mkdir_p(File.dirname(pointer_path))

  # if passed revision is not identical to commit_hash
  # then revision has to be a branch name or tag name.
  # In that case store a ref.
  _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

  if !force_download
    if File.exist?(pointer_path)
      return pointer_path
    end

    if File.exist?(blob_path)
      # we have the blob already, but not the pointer
      _create_symlink(blob_path, pointer_path, new_blob: false)
      return pointer_path
    end
  end

  # Prevent parallel downloads of the same file with a lock.
  # etag could be duplicated across repos,
  # lock_path = File.join(locks_dir, repo_folder_name(repo_id: repo_id, repo_type: repo_type), "#{etag}.lock")

  _download_to_tmp_and_move(
    incomplete_path: Pathname.new(blob_path + ".incomplete"),
    destination_path: Pathname.new(blob_path),
    url_to_download: url_to_download,
    proxies: proxies,
    headers: headers,
    expected_size: expected_size,
    filename: filename,
    force_download: force_download
  )
  _create_symlink(blob_path, pointer_path, new_blob: true)

  pointer_path
end
423
+
424
# Look up a file in the local cache without any network access.
# Returns:
#   - the cached file path (String) when present,
#   - CACHED_NO_EXIST when the cache records that the file does not exist
#     in the remote repo (".no_exist" marker),
#   - nil when nothing is known locally.
def try_to_load_from_cache(
  repo_id,
  filename,
  cache_dir: nil,
  revision: nil,
  repo_type: nil
)
  if revision.nil?
    revision = "main"
  end
  if repo_type.nil?
    repo_type = "model"
  end
  if !REPO_TYPES.include?(repo_type)
    raise ArgumentError, "Invalid repo type: #{repo_type}. Accepted repo types are: #{REPO_TYPES}"
  end
  if cache_dir.nil?
    cache_dir = HF_HUB_CACHE
  end

  object_id = repo_id.gsub("/", "--")
  repo_cache = File.join(cache_dir, "#{repo_type}s--#{object_id}")
  if !Dir.exist?(repo_cache)
    # No cache for this model
    return nil
  end

  refs_dir = File.join(repo_cache, "refs")
  snapshots_dir = File.join(repo_cache, "snapshots")
  no_exist_dir = File.join(repo_cache, ".no_exist")

  # Resolve refs (for instance to convert main to the associated commit sha)
  if Dir.exist?(refs_dir)
    revision_file = File.join(refs_dir, revision)
    if File.exist?(revision_file)
      revision = File.read(revision_file)
    end
  end

  # Check if file is cached as "no_exist"
  if File.exist?(File.join(no_exist_dir, revision, filename))
    return CACHED_NO_EXIST
  end

  # Check if revision folder exists
  if !Dir.exist?(snapshots_dir)
    return nil
  end
  cached_shas = Dir.glob("*", base: snapshots_dir)
  if !cached_shas.include?(revision)
    # No cache for this revision and we won't try to return a random revision
    return nil
  end

  # Check if file exists in cache
  cached_file = File.join(snapshots_dir, revision, filename)
  File.exist?(cached_file) ? cached_file : nil
end
482
+
483
# Fetch metadata (commit hash, etag, final location, size) for a file URL via
# an HTTP HEAD request, following relative redirects. Returns HfFileMetadata.
def get_hf_file_metadata(
  url,
  token: nil,
  proxies: nil,
  timeout: DEFAULT_REQUEST_TIMEOUT,
  library_name: nil,
  library_version: nil,
  user_agent: nil,
  headers: nil
)
  headers =
    build_hf_headers(
      token: token,
      library_name: library_name,
      library_version: library_version,
      user_agent: user_agent,
      headers: headers
    )
  headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file

  # Retrieve metadata
  r =
    _request_wrapper(
      "HEAD",
      url,
      headers: headers,
      allow_redirects: false,
      follow_relative_redirects: true,
      proxies: proxies,
      timeout: timeout
    )
  hf_raise_for_status(r)

  # Return
  HfFileMetadata.new(
    commit_hash: r[HUGGINGFACE_HEADER_X_REPO_COMMIT],
    # We favor a custom header indicating the etag of the linked resource, and
    # we fallback to the regular etag header.
    etag: _normalize_etag(r[HUGGINGFACE_HEADER_X_LINKED_ETAG] || r["etag"]),
    # Either from response headers (if redirected) or defaults to request url
    # Do not use directly `url`, as `_request_wrapper` might have followed relative
    # redirects.
    location: r["location"] || r.uri.to_s,
    size: _int_or_none(r[HUGGINGFACE_HEADER_X_LINKED_SIZE] || r["content-length"])
  )
end
529
+
530
# Resolve [url_to_download, etag, commit_hash, expected_size, head_call_error]
# for a file via a HEAD request to the Hub. Exactly one shape is returned:
# full metadata with head_call_error == nil, or all-nil metadata plus the
# error that prevented the HEAD call (currently only offline mode).
#
# NOTE: +headers+ is mutated in place — the authorization header is removed
# when the download is redirected to a different host (e.g. the LFS CDN).
def _get_metadata_or_catch_error(
  repo_id:,
  filename:,
  repo_type:,
  revision:,
  endpoint:,
  proxies:,
  etag_timeout:,
  headers:, # mutated inplace!
  token:,
  local_files_only:,
  relative_filename: nil, # only used to store `.no_exists` in cache
  storage_folder: nil # only used to store `.no_exists` in cache
)
  if local_files_only
    return [
      nil,
      nil,
      nil,
      nil,
      OfflineModeIsEnabled.new(
        "Cannot access file since 'local_files_only: true' has been set. (repo_id: #{repo_id}, repo_type: #{repo_type}, revision: #{revision}, filename: #{filename})"
      )
    ]
  end

  url = hf_hub_url(repo_id, filename, repo_type: repo_type, revision: revision, endpoint: endpoint)
  url_to_download = url
  etag = nil
  commit_hash = nil
  expected_size = nil
  # Consistent name throughout (was `head_error_call` in places, which made
  # the sanity check below reference an undefined local).
  head_call_error = nil

  metadata = nil
  begin
    metadata =
      get_hf_file_metadata(
        url,
        proxies: proxies,
        timeout: etag_timeout,
        headers: headers,
        token: token
      )
  rescue => e
    # TODO: mirror huggingface_hub and convert recoverable network errors
    # into head_call_error instead of re-raising.
    raise e
  end

  # Commit hash must exist
  commit_hash = metadata.commit_hash
  if commit_hash.nil?
    raise Todo
  end

  # Etag must exist
  etag = metadata.etag
  if etag.nil?
    raise Todo
  end

  # Expected (uncompressed) size
  expected_size = metadata.size
  if expected_size.nil?
    raise Todo
  end

  if metadata.location != url
    # The server redirected us (e.g. to a CDN); download from there instead.
    url_to_download = metadata.location
    if netloc(URI.parse(url)) != netloc(URI.parse(metadata.location))
      # Remove authorization header when downloading a LFS blob
      headers.delete("authorization")
    end
  end

  if !(local_files_only || !etag.nil? || !head_call_error.nil?)
    raise "etag is empty due to uncovered problems"
  end

  [url_to_download, etag, commit_hash, expected_size, head_call_error]
end
612
+
613
# Translate a failed HEAD call (no usable metadata and no local file) into the
# most informative exception for the caller. Always raises.
def _raise_on_head_call_error(head_call_error, force_download, local_files_only)
  if force_download
    # No head call => we cannot force download.
    if local_files_only
      raise ArgumentError, "Cannot pass 'force_download: true' and 'local_files_only: true' at the same time."
    end
    if head_call_error.is_a?(OfflineModeIsEnabled)
      raise ArgumentError, "Cannot pass 'force_download: true' when offline mode is enabled."
    end
    raise ArgumentError, "Force download failed due to the above error."
  end

  # If we couldn't find an appropriate file on disk, raise an error.
  # If files cannot be found and local_files_only=True,
  # the models might've been found if local_files_only=False
  # Notify the user about that
  if local_files_only
    raise LocalEntryNotFoundError,
      "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable" +
      " hf.co look-ups and downloads online, set 'local_files_only' to false."
  end

  if head_call_error.is_a?(RepositoryNotFoundError) || head_call_error.is_a?(GatedRepoError)
    # Repo not found or gated => let's raise the actual error
    raise head_call_error
  end

  # Otherwise: most likely a connection issue or Hub downtime => let's warn the user
  raise LocalEntryNotFoundError,
    "An error happened while trying to locate the file on the Hub and we cannot find the requested files" +
    " in the local cache. Please check your connection and try again or make sure your Internet connection" +
    " is on."
end
644
+
645
# Download to "<blob>.incomplete" then move the finished file into place,
# resuming a partial download when possible. No-op when +destination_path+
# already exists (unless +force_download+ is set).
def _download_to_tmp_and_move(
  incomplete_path:,
  destination_path:,
  url_to_download:,
  proxies:,
  headers:,
  expected_size:,
  filename:,
  force_download:
)
  if destination_path.exist? && !force_download
    # Do nothing if already exists (except if force_download=True)
    return
  end

  if incomplete_path.exist? && (force_download || (HF_HUB_ENABLE_HF_TRANSFER && !proxies))
    # By default, we will try to resume the download if possible.
    # However, if the user has set `force_download=True` or if `hf_transfer` is enabled, then we should
    # not resume the download => delete the incomplete file.
    message = "Removing incomplete file '#{incomplete_path}'"
    if force_download
      message += " (force_download: true)"
    elsif HF_HUB_ENABLE_HF_TRANSFER && !proxies
      message += " (hf_transfer: true)"
    end
    Transformers.logger.info(message)
    incomplete_path.unlink #(missing_ok=True)
  end

  # Open in append mode so an existing partial file is kept; its current size
  # becomes the resume offset.
  incomplete_path.open("ab") do |f|
    f.seek(0, IO::SEEK_END)
    resume_size = f.tell
    message = "Downloading '#{filename}' to '#{incomplete_path}'"
    if resume_size > 0 && !expected_size.nil?
      message += " (resume from #{resume_size}/#{expected_size})"
    end
    Transformers.logger.info(message)

    if !expected_size.nil? # might be None if HTTP header not set correctly
      # Check disk space in both tmp and destination path
      _check_disk_space(expected_size, incomplete_path.parent)
      _check_disk_space(expected_size, destination_path.parent)
    end

    http_get(
      url_to_download,
      f,
      proxies: proxies,
      resume_size: resume_size,
      headers: headers,
      expected_size: expected_size
    )
  end

  Transformers.logger.info("Download complete. Moving file to #{destination_path}")
  _chmod_and_move(incomplete_path, destination_path)
end
702
+
703
# Coerce a header value to Integer, passing nil through unchanged.
def _int_or_none(value)
  return nil if value.nil?
  value.to_i
end
706
+
707
# Move +src+ to +dst+, first aligning src's file mode with the mode that new
# files get in the cache directory: a throwaway file is touched two levels up
# from +dst+, its mode is read via stat, applied to +src+, then the probe file
# is removed and the move happens.
def _chmod_and_move(src, dst)
  tmp_file = dst.parent.parent / "tmp_#{SecureRandom.uuid}"
  begin
    FileUtils.touch(tmp_file)
    # Mode of a freshly created file in the cache dir (honours the umask).
    cache_dir_mode = Pathname.new(tmp_file).stat.mode
    src.chmod(cache_dir_mode)
  ensure
    begin
      tmp_file.unlink
    rescue Errno::ENOENT
      # fails if `tmp_file.touch()` failed => do nothing
      # See https://github.com/huggingface/huggingface_hub/issues/2359
    end
  end

  FileUtils.move(src.to_s, dst.to_s)
end
724
+
725
# Build "<storage_folder>/snapshots/<revision>/<relative_filename>" and verify
# the result stays inside the snapshots folder (guards against path escapes
# via a crafted relative filename such as "../..").
def _get_pointer_path(storage_folder, revision, relative_filename)
  snapshot_path = File.join(storage_folder, "snapshots")
  pointer_path = File.join(snapshot_path, revision, relative_filename)
  abs_pointer = Pathname.new(File.absolute_path(pointer_path))
  abs_snapshot = Pathname.new(File.absolute_path(snapshot_path))
  unless parents(abs_pointer).include?(abs_snapshot)
    raise ArgumentError,
      "Invalid pointer path: cannot create pointer path in snapshot folder if" +
      " `storage_folder: #{storage_folder.inspect}`, `revision: #{revision.inspect}` and" +
      " `relative_filename: #{relative_filename.inspect}`."
  end
  pointer_path
end
736
+
737
+ # additional methods
738
+
739
# Return "host:port" for an absolute URI, or "" for a relative one.
# Used to compare the authority component of two URLs.
def netloc(uri)
  parts = [uri.host, uri.port].compact
  parts.join(":")
end
742
+
743
# All ancestor directories of +path+, nearest first, stopping at the
# filesystem root (where a path equals its own parent). Capped at 100 levels
# as a safety net against pathological paths.
def parents(path)
  result = []
  100.times do
    up = path.parent
    break if up == path
    result << up
    path = up
  end
  result
end
754
+
755
# Render a one-line textual progress bar, e.g. "model.bin |████      |".
#
# filename      - label shown before the bar.
# width         - total width (in characters) the line may occupy.
# size          - bytes downloaded so far.
# expected_size - total bytes; may be nil or 0 (the bar then stays empty
#                 instead of raising on division by zero — Infinity/NaN#round
#                 raises FloatDomainError).
def display_progress(filename, width, size, expected_size)
  bar_width = width - (filename.length + 3)
  progress = expected_size.to_i > 0 ? size / expected_size.to_f : 0.0
  # Clamp so a short/garbled expected size can never overflow the bar or
  # produce a negative string-repeat count.
  done = (progress.clamp(0.0, 1.0) * bar_width).round
  not_done = bar_width - done
  "#{filename} |#{"█" * done}#{" " * not_done}|"
end
762
+ end
763
+ end
764
+ end